From 8b2a3a3781a4a092d508bae50ba58646410d5c55 Mon Sep 17 00:00:00 2001 From: caffix Date: Fri, 20 Dec 2024 02:01:41 -0500 Subject: [PATCH] started to implement the neo4j repository edge methods --- cache/entity.go | 6 +- repository/neo4j/edge.go | 302 ++++++++++++++++++++++++++++++++ repository/neo4j/edge_tag.go | 297 +++++++++++++++++++++++++++++++ repository/neo4j/edge_test.go | 5 + repository/neo4j/entity.go | 2 +- repository/neo4j/entity_tag.go | 61 +++---- repository/neo4j/entity_test.go | 18 ++ repository/sqlrepo/edge.go | 4 +- repository/sqlrepo/entity.go | 2 +- 9 files changed, 654 insertions(+), 43 deletions(-) create mode 100644 repository/neo4j/edge.go create mode 100644 repository/neo4j/edge_tag.go create mode 100644 repository/neo4j/edge_test.go diff --git a/cache/entity.go b/cache/entity.go index 543225c..9758dc6 100644 --- a/cache/entity.go +++ b/cache/entity.go @@ -140,8 +140,10 @@ func (c *Cache) DeleteEntity(id string) error { return err } - if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - _ = c.db.DeleteEntity(e[0].ID) + if ents, err := c.db.FindEntitiesByContent(entity.Asset, time.Time{}); err == nil && len(ents) > 0 { + for _, e := range ents { + _ = c.db.DeleteEntity(e.ID) + } } return nil diff --git a/repository/neo4j/edge.go b/repository/neo4j/edge.go new file mode 100644 index 0000000..9294e8a --- /dev/null +++ b/repository/neo4j/edge.go @@ -0,0 +1,302 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package neo4j + +import ( + "context" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/owasp-amass/asset-db/types" + oam "github.com/owasp-amass/open-asset-model" + "gorm.io/gorm" +) + +// CreateEdge creates an edge between two entities in the database. +// The edge is established by creating a new Edge in the database, linking the two entities. +// Returns the created edge as a types.Edge or an error if the link creation fails. +func (neo *neoRepository) CreateEdge(edge *types.Edge) (*types.Edge, error) { + if edge == nil || edge.Relation == nil || edge.FromEntity == nil || + edge.FromEntity.Asset == nil || edge.ToEntity == nil || edge.ToEntity.Asset == nil { + return nil, errors.New("failed input validation checks") + } + + if !oam.ValidRelationship(edge.FromEntity.Asset.AssetType(), + edge.Relation.Label(), edge.Relation.RelationType(), edge.ToEntity.Asset.AssetType()) { + return &types.Edge{}, fmt.Errorf("%s -%s-> %s is not valid in the taxonomy", + edge.FromEntity.Asset.AssetType(), edge.Relation.Label(), edge.ToEntity.Asset.AssetType()) + } + + var updated time.Time + if edge.LastSeen.IsZero() { + updated = time.Now().UTC() + } else { + updated = edge.LastSeen.UTC() + } + // ensure that duplicate relationships are not entered into the database + if e, found := neo.isDuplicateEdge(edge, updated); found { + return e, nil + } + + fromEntityId, err := strconv.ParseUint(edge.FromEntity.ID, 10, 64) + if err != nil { + return nil, err + } + + toEntityId, err := strconv.ParseUint(edge.ToEntity.ID, 10, 64) + if err != nil { + return nil, err + } + + jsonContent, err := edge.Relation.JSON() + if err != nil { + return nil, err + } + + r := Edge{ + Type: string(edge.Relation.RelationType()), + Content: jsonContent, + FromEntityID: fromEntityId, + ToEntityID: toEntityId, + UpdatedAt: updated, + } + if edge.CreatedAt.IsZero() { + r.CreatedAt = time.Now().UTC() + } else { + r.CreatedAt = edge.CreatedAt.UTC() + } + + result := sql.db.Create(&r) + if err := result.Error; err != nil { + return nil, err + } + return toEdge(r), nil +} + +// isDuplicateEdge checks if the relationship between source and dest already exists. +func (neo *neoRepository) isDuplicateEdge(edge *types.Edge, updated time.Time) (*types.Edge, bool) { + var dup bool + var e *types.Edge + + if outs, err := neo.OutgoingEdges(edge.FromEntity, time.Time{}, edge.Relation.Label()); err == nil { + for _, out := range outs { + if edge.ToEntity.ID == out.ToEntity.ID && reflect.DeepEqual(edge.Relation, out.Relation) { + _ = neo.edgeSeen(out, updated) + + e, err = neo.FindEdgeById(out.ID) + if err != nil { + return nil, false + } + + dup = true + break + } + } + } + + return e, dup +} + +// edgeSeen updates the updated_at timestamp for the specified edge. +func (neo *neoRepository) edgeSeen(rel *types.Edge, updated time.Time) error { + id, err := strconv.ParseUint(rel.ID, 10, 64) + if err != nil { + return err + } + + jsonContent, err := rel.Relation.JSON() + if err != nil { + return err + } + + fromEntityId, err := strconv.ParseUint(rel.FromEntity.ID, 10, 64) + if err != nil { + return err + } + + toEntityId, err := strconv.ParseUint(rel.ToEntity.ID, 10, 64) + if err != nil { + return err + } + + r := Edge{ + ID: id, + Type: string(rel.Relation.RelationType()), + Content: jsonContent, + FromEntityID: fromEntityId, + ToEntityID: toEntityId, + CreatedAt: rel.CreatedAt, + UpdatedAt: updated, + } + + result := sql.db.Save(&r) + if err := result.Error; err != nil { + return err + } + return nil +} + +func (neo *neoRepository) FindEdgeById(id string) (*types.Edge, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, + "MATCH (from:Entity)-[r]->(to:Entity) WHERE r.elementId = $eid RETURN r, from.entity_id AS fid, to.entity_id AS tid", + map[string]interface{}{ + "eid": id, + }, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + + if err != nil { + return nil, err + } + if len(result.Records) == 0 { + return nil, errors.New("no edge was found") + } + + r, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Relationship](result.Records[0], "r") + if err != nil { + return nil, err + } + if isnil { + return nil, errors.New("the record value for the relationship is nil") + } + + fid, isnil, err := neo4jdb.GetRecordValue[string](result.Records[0], "fid") + if err != nil { + return nil, err + } + if isnil { + return nil, errors.New("the record value for the from entity ID is nil") + } + + tid, isnil, err := neo4jdb.GetRecordValue[string](result.Records[0], "tid") + if err != nil { + return nil, err + } + if isnil { + return nil, errors.New("the record value for the to entity ID is nil") + } + + edge, err := relationshipToEdge(r) + if err != nil { + return nil, err + } + edge.FromEntity = &types.Entity{ID: fid} + edge.ToEntity = &types.Entity{ID: tid} + return edge, err +} + +// IncomingEdges finds all edges pointing to the entity of the specified labels and last seen after the since parameter. +// If since.IsZero(), the parameter will be ignored. +// If no labels are specified, all incoming eges are returned. +func (neo *neoRepository) IncomingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) { + entityId, err := strconv.ParseInt(entity.ID, 10, 64) + if err != nil { + return nil, err + } + + var edges []Edge + var result *gorm.DB + if since.IsZero() { + result = sql.db.Where("to_entity_id = ?", entityId).Find(&edges) + } else { + result = sql.db.Where("to_entity_id = ? AND updated_at >= ?", entityId, since.UTC()).Find(&edges) + } + if err := result.Error; err != nil { + return nil, err + } + + var results []Edge + if len(labels) > 0 { + for _, edge := range edges { + e := &edge + + if rel, err := e.Parse(); err == nil { + for _, label := range labels { + if label == rel.Label() { + results = append(results, edge) + break + } + } + } + } + } else { + results = edges + } + + if len(results) == 0 { + return nil, errors.New("zero edges found") + } + return toEdges(results), nil +} + +// OutgoingEdges finds all edges from the entity of the specified labels and last seen after the since parameter. +// If since.IsZero(), the parameter will be ignored. +// If no labels are specified, all outgoing edges are returned. +func (neo *neoRepository) OutgoingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) { + entityId, err := strconv.ParseInt(entity.ID, 10, 64) + if err != nil { + return nil, err + } + + var edges []Edge + var result *gorm.DB + if since.IsZero() { + result = sql.db.Where("from_entity_id = ?", entityId).Find(&edges) + } else { + result = sql.db.Where("from_entity_id = ? AND updated_at >= ?", entityId, since.UTC()).Find(&edges) + } + if err := result.Error; err != nil { + return nil, err + } + + var results []Edge + if len(labels) > 0 { + for _, edge := range edges { + e := &edge + + if rel, err := e.Parse(); err == nil { + for _, label := range labels { + if label == rel.Label() { + results = append(results, edge) + break + } + } + } + } + } else { + results = edges + } + + if len(results) == 0 { + return nil, errors.New("zero edges found") + } + return toEdges(results), nil +} + +// DeleteEdge removes an edge in the database by its ID. +// It takes a string representing the edge ID and removes the corresponding edge from the database. +// Returns an error if the edge is not found. +func (neo *neoRepository) DeleteEdge(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := neo4jdb.ExecuteQuery(ctx, neo.db, + "MATCH ()-[r]->() WHERE r.elementId = $eid DELETE r", + map[string]interface{}{ + "eid": id, + }, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + + return err +} diff --git a/repository/neo4j/edge_tag.go b/repository/neo4j/edge_tag.go new file mode 100644 index 0000000..aeb177d --- /dev/null +++ b/repository/neo4j/edge_tag.go @@ -0,0 +1,297 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package neo4j + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/owasp-amass/asset-db/types" + oam "github.com/owasp-amass/open-asset-model" +) + +// CreateEdgeTag creates a new edge tag in the database. +// It takes an EdgeTag as input and persists it in the database. +// The property is serialized to JSON and stored in the Content field of the EdgeTag struct. +// Returns the created edge tag as a types.EdgeTag or an error if the creation fails. +func (neo *neoRepository) CreateEdgeTag(edge *types.Edge, input *types.EdgeTag) (*types.EdgeTag, error) { + var tag *types.EdgeTag + + if input == nil { + return nil, errors.New("the input edge tag is nil") + } + // ensure that duplicate entities are not entered into the database + if tags, err := neo.FindEdgeTagsByContent(input.Property, time.Time{}); err == nil && len(tags) == 1 { + t := tags[0] + + if input.Property.PropertyType() != t.Property.PropertyType() { + return nil, errors.New("the property type does not match the existing tag") + } + + qnode, err := queryNodeByPropertyKeyValue("p", "EdgeTag", t.Property) + if err != nil { + return nil, err + } + + t.LastSeen = time.Now() + props, err := edgeTagPropsMap(t) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, + "MATCH "+qnode+" SET p = $props RETURN p", + map[string]interface{}{"props": props}, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + if err != nil { + return nil, err + } + if len(result.Records) == 0 { + return nil, errors.New("no records returned from the query") + } + + node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "p") + if err != nil { + return nil, err + } + if isnil { + return nil, errors.New("the record value for the node is nil") + } + + if extracted, err := nodeToEdgeTag(node); err == nil && extracted != nil { + tag = extracted + } + } else { + if input.ID == "" { + input.ID = neo.uniqueEdgeTagID() + } + if input.CreatedAt.IsZero() { + input.CreatedAt = time.Now() + } + if input.LastSeen.IsZero() { + input.LastSeen = time.Now() + } + + props, err := edgeTagPropsMap(input) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + query := fmt.Sprintf("CREATE (p:EdgeTag:%s $props) RETURN p", input.Property.PropertyType()) + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, + map[string]interface{}{"props": props}, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + if err != nil { + return nil, err + } + if len(result.Records) == 0 { + return nil, errors.New("no records returned from the query") + } + + node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "p") + if err != nil { + return nil, err + } + if isnil { + return nil, errors.New("the record value for the node is nil") + } + + if t, err := nodeToEdgeTag(node); err == nil && t != nil { + tag = t + } + } + + if tag == nil { + return nil, errors.New("failed to create the edge tag") + } + return tag, nil +} + +// CreateEdgeProperty creates a new edge tag in the database. +// It takes an oam.Property as input and persists it in the database. +// The property is serialized to JSON and stored in the Content field of the EdgeTag struct. +// Returns the created edge tag as a types.EdgeTag or an error if the creation fails. +func (neo *neoRepository) CreateEdgeProperty(edge *types.Edge, prop oam.Property) (*types.EdgeTag, error) { + return neo.CreateEdgeTag(edge, &types.EdgeTag{Property: prop}) +} + +func (neo *neoRepository) uniqueEdgeTagID() string { + for { + id := uuid.New().String() + if _, err := neo.FindEdgeTagById(id); err != nil { + return id + } + } +} + +// FindEdgeTagById finds an edge tag in the database by the ID. +// It takes a string representing the edge tag ID and retrieves the corresponding tag from the database. +// Returns the discovered tag as a types.EdgeTag or an error if the asset is not found. +func (neo *neoRepository) FindEdgeTagById(id string) (*types.EdgeTag, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, + "MATCH (p:EdgeTag {tag_id: $tid}) RETURN p", + map[string]interface{}{"tid": id}, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + if err != nil { + return nil, err + } + if len(result.Records) == 0 { + return nil, fmt.Errorf("the edge tag with ID %s was not found", id) + } + + node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "p") + if err != nil { + return nil, err + } + if isnil { + return nil, errors.New("the record value for the node is nil") + } + return nodeToEdgeTag(node) +} + +// FindEdgeTagsByContent finds edge tags in the database that match the provided property data and updated_at after the since parameter. +// It takes an oam.Property as input and searches for edge tags with matching content in the database. +// If since.IsZero(), the parameter will be ignored. +// The property data is serialized to JSON and compared against the Content field of the EdgeTag struct. +// Returns a slice of matching edge tags as []*types.EdgeTag or an error if the search fails. +func (neo *neoRepository) FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*types.EdgeTag, error) { + qnode, err := queryNodeByPropertyKeyValue("p", "EdgeTag", prop) + if err != nil { + return nil, err + } + + query := "MATCH " + qnode + " RETURN p" + if !since.IsZero() { + query = fmt.Sprintf("MATCH %s WHERE p.updated_at >= localDateTime('%s') RETURN p", qnode, timeToNeo4jTime(since)) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + if err != nil { + return nil, err + } + if len(result.Records) == 0 { + return nil, errors.New("no edge tags found") + } + + node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "p") + if err != nil { + return nil, err + } + if isnil { + return nil, errors.New("the record value for the node is nil") + } + + tag, err := nodeToEdgeTag(node) + if err != nil { + return nil, err + } + return []*types.EdgeTag{tag}, nil +} + +// GetEdgeTags finds all tags for the edge with the specified names and last seen after the since parameter. +// If since.IsZero(), the parameter will be ignored. +// If no names are specified, all tags for the specified edge are returned. +func (neo *neoRepository) GetEdgeTags(edge *types.Edge, since time.Time, names ...string) ([]*types.EdgeTag, error) { + query := fmt.Sprintf("MATCH (p:EdgeTag {edge_id: '%s'}) RETURN p", edge.ID) + if !since.IsZero() { + query = fmt.Sprintf("MATCH (p:EdgeTag {edge_id: '%s'}) WHERE p.updated_at >= localDateTime('%s') RETURN p", edge.ID, timeToNeo4jTime(since)) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + if err != nil { + return nil, err + } + if len(result.Records) == 0 { + return nil, errors.New("no edge tags found") + } + + var results []*types.EdgeTag + for _, record := range result.Records { + node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](record, "p") + if err != nil { + continue + } + if isnil { + continue + } + + tag, err := nodeToEdgeTag(node) + if err != nil { + continue + } + + if len(names) > 0 { + var found bool + n := tag.Property.Name() + + for _, name := range names { + if name == n { + found = true + break + } + } + if !found { + continue + } + } + + results = append(results, tag) + } + + if len(results) == 0 { + return nil, errors.New("zero tags found") + } + return results, nil +} + +// DeleteEdgeTag removes an edge tag in the database by its ID. +// It takes a string representing the edge tag ID and removes the corresponding tag from the database. +// Returns an error if the tag is not found. +func (neo *neoRepository) DeleteEdgeTag(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := neo4jdb.ExecuteQuery(ctx, neo.db, + "MATCH (n:EdgeTag {tag_id: $tid}) DETACH DELETE n", + map[string]interface{}{ + "tid": id, + }, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + + return err +} diff --git a/repository/neo4j/edge_test.go b/repository/neo4j/edge_test.go new file mode 100644 index 0000000..1f2e616 --- /dev/null +++ b/repository/neo4j/edge_test.go @@ -0,0 +1,5 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package neo4j diff --git a/repository/neo4j/entity.go b/repository/neo4j/entity.go index 54bb464..e025f30 100644 --- a/repository/neo4j/entity.go +++ b/repository/neo4j/entity.go @@ -27,7 +27,7 @@ func (neo *neoRepository) CreateEntity(input *types.Entity) (*types.Entity, erro return nil, errors.New("the input entity is nil") } // ensure that duplicate entities are not entered into the database - if entities, err := neo.FindEntitiesByContent(input.Asset, time.Time{}); err == nil && len(entities) == 1 { + if entities, err := neo.FindEntitiesByContent(input.Asset, time.Time{}); err == nil && len(entities) > 0 { e := entities[0] if input.Asset.AssetType() != e.Asset.AssetType() { diff --git a/repository/neo4j/entity_tag.go b/repository/neo4j/entity_tag.go index ad1e8f7..55e033d 100644 --- a/repository/neo4j/entity_tag.go +++ b/repository/neo4j/entity_tag.go @@ -8,7 +8,6 @@ import ( "context" "errors" "fmt" - "strconv" "time" "github.com/google/uuid" @@ -25,7 +24,7 @@ func (neo *neoRepository) CreateEntityTag(entity *types.Entity, input *types.Ent var tag *types.EntityTag if input == nil { - return nil, errors.New("the input entity is nil") + return nil, errors.New("the input entity tag is nil") } // ensure that duplicate entities are not entered into the database if tags, err := neo.FindEntityTagsByContent(input.Property, time.Time{}); err == nil && len(tags) == 1 { @@ -239,49 +238,37 @@ func (neo *neoRepository) GetEntityTags(entity *types.Entity, since time.Time, n return nil, errors.New("no entity tags found") } - node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "p") - if err != nil { - return nil, err - } - if isnil { - return nil, errors.New("the record value for the node is nil") - } - - tag, err := nodeToEntityTag(node) - if err != nil { - return nil, err - } - return []*types.EntityTag{tag}, nil - var results []*types.EntityTag - for _, tag := range tags { - t := &tag + for _, record := range result.Records { + node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](record, "p") + if err != nil { + continue + } + if isnil { + continue + } - if prop, err := t.Parse(); err == nil { - found := true + tag, err := nodeToEntityTag(node) + if err != nil { + continue + } - if len(names) > 0 { - found = false - n := prop.Name() + if len(names) > 0 { + var found bool + n := tag.Property.Name() - for _, name := range names { - if name == n { - found = true - break - } + for _, name := range names { + if name == n { + found = true + break } } - - if found { - results = append(results, &types.EntityTag{ - ID: strconv.Itoa(int(t.ID)), - CreatedAt: t.CreatedAt.In(time.UTC).Local(), - LastSeen: t.UpdatedAt.In(time.UTC).Local(), - Property: prop, - Entity: entity, - }) + if !found { + continue } } + + results = append(results, tag) } if len(results) == 0 { diff --git a/repository/neo4j/entity_test.go b/repository/neo4j/entity_test.go index 8772338..8e5f727 100644 --- a/repository/neo4j/entity_test.go +++ b/repository/neo4j/entity_test.go @@ -146,3 +146,21 @@ func TestFindEntitiesByType(t *testing.T) { t.Errorf("Failed to return the correct number of entities") } } + +func TestDeleteEntity(t *testing.T) { + entity, err := store.CreateEntity(&types.Entity{ + Asset: &domain.FQDN{ + Name: "delete.entity", + }, + }) + assert.NoError(t, err) + + err = store.DeleteEntity(entity.ID) + assert.NoError(t, err) + + _, err := store.FindEntityById(entity.ID) + assert.Error(t, err) + + err = store.DeleteEntity(entity.ID) + assert.Error(t, err) +} diff --git a/repository/sqlrepo/edge.go b/repository/sqlrepo/edge.go index 52218ed..fea9e0d 100644 --- a/repository/sqlrepo/edge.go +++ b/repository/sqlrepo/edge.go @@ -38,7 +38,7 @@ func (sql *sqlRepository) CreateEdge(edge *types.Edge) (*types.Edge, error) { updated = edge.LastSeen.UTC() } // ensure that duplicate relationships are not entered into the database - if e, found := sql.isDuplicateEdge(edge); found { + if e, found := sql.isDuplicateEdge(edge, updated); found { return e, nil } @@ -85,7 +85,7 @@ func (sql *sqlRepository) isDuplicateEdge(edge *types.Edge, updated time.Time) ( if outs, err := sql.OutgoingEdges(edge.FromEntity, time.Time{}, edge.Relation.Label()); err == nil { for _, out := range outs { if edge.ToEntity.ID == out.ToEntity.ID && reflect.DeepEqual(edge.Relation, out.Relation) { - _ = sql.edgeSeen(out) + _ = sql.edgeSeen(out, updated) e, err = sql.FindEdgeById(out.ID) if err != nil { diff --git a/repository/sqlrepo/entity.go b/repository/sqlrepo/entity.go index ea65904..d6553bc 100644 --- a/repository/sqlrepo/entity.go +++ b/repository/sqlrepo/entity.go @@ -30,7 +30,7 @@ func (sql *sqlRepository) CreateEntity(input *types.Entity) (*types.Entity, erro } // ensure that duplicate entities are not entered into the database - if entities, err := sql.FindEntitiesByContent(input.Asset, time.Time{}); err == nil && len(entities) == 1 { + if entities, err := sql.FindEntitiesByContent(input.Asset, time.Time{}); err == nil && len(entities) > 0 { e := entities[0] if input.Asset.AssetType() == e.Asset.AssetType() {