diff --git a/cache/edge.go b/cache/edge.go index 3a6530a..389c230 100644 --- a/cache/edge.go +++ b/cache/edge.go @@ -32,7 +32,7 @@ func (c *Cache) CreateEdge(edge *types.Edge) (*types.Edge, error) { } if _, _, found := c.checkCacheEdgeTag(edge, "cache_create_edge"); !found { - _ = c.createCacheEdgeTag(edge, "cache_create_edge", time.Now()) + _ = c.createCacheEdgeTag(e, "cache_create_edge", time.Now()) c.appendToDBQueue(func() { s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) @@ -58,6 +58,14 @@ func (c *Cache) CreateEdge(edge *types.Edge) (*types.Edge, error) { return e, nil } +// FindEdgeById implements the Repository interface. +func (c *Cache) FindEdgeById(id string) (*types.Edge, error) { + c.Lock() + defer c.Unlock() + + return c.cache.FindEdgeById(id) +} + // IncomingEdges implements the Repository interface. func (c *Cache) IncomingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) { var dbquery bool @@ -150,7 +158,7 @@ func (c *Cache) OutgoingEdges(entity *types.Entity, since time.Time, labels ...s defer func() { done <- struct{}{} }() if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - dbedges, dberr = c.db.IncomingEdges(e[0], since, labels...) + dbedges, dberr = c.db.OutgoingEdges(e[0], since, labels...) for i, edge := range dbedges { if e, err := c.db.FindEntityById(edge.ToEntity.ID); err == nil && e != nil { @@ -189,7 +197,7 @@ func (c *Cache) OutgoingEdges(entity *types.Entity, since time.Time, labels ...s defer c.Unlock() } - return c.cache.IncomingEdges(entity, since, labels...) + return c.cache.OutgoingEdges(entity, since, labels...) } // DeleteEdge implements the Repository interface. @@ -197,11 +205,6 @@ func (c *Cache) DeleteEdge(id string) error { c.Lock() defer c.Unlock() - err := c.cache.DeleteEdge(id) - if err != nil { - return err - } - edge, err := c.cache.FindEdgeById(id) if err != nil { return nil @@ -217,6 +220,10 @@ func (c *Cache) DeleteEdge(id string) error { return nil } + if err := c.cache.DeleteEdge(id); err != nil { + return err + } + c.appendToDBQueue(func() { s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) if err != nil || len(s) != 1 { @@ -235,7 +242,7 @@ func (c *Cache) DeleteEdge(id string) error { var target *types.Edge for _, e := range edges { - if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge.Relation) { + if e.ToEntity.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge.Relation) { target = e break } diff --git a/cache/edge_test.go b/cache/edge_test.go new file mode 100644 index 0000000..0158309 --- /dev/null +++ b/cache/edge_test.go @@ -0,0 +1,462 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "os" + "reflect" + "testing" + "time" + + "github.com/caffix/stringset" + "github.com/owasp-amass/asset-db/types" + "github.com/owasp-amass/open-asset-model/domain" + "github.com/owasp-amass/open-asset-model/relation" + "github.com/stretchr/testify/assert" +) + +func TestCreateEdge(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2) + assert.NoError(t, err) + defer c.Close() + + now := time.Now() + ctime := now.Add(-8 * time.Hour) + before := ctime.Add(-2 * time.Second) + after := ctime.Add(2 * time.Second) + + entity1, err := c.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: "owasp.org"}, + }) + assert.NoError(t, err) + + entity2, err := c.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: "www.owasp.org"}, + }) + assert.NoError(t, err) + + edge, err := c.CreateEdge(&types.Edge{ + CreatedAt: ctime, + LastSeen: ctime, + Relation: &relation.BasicDNSRelation{ + Name: "dns_record", + Header: relation.RRHeader{ + RRType: 5, + Class: 1, + TTL: 3600, + }, + }, + FromEntity: entity2, + ToEntity: entity1, + }) + assert.NoError(t, err) + + if edge.CreatedAt.Before(before) || edge.CreatedAt.After(after) { + t.Errorf("create time: %s, before time: %s, after time: %s", edge.CreatedAt.Format(time.RFC3339Nano), before.Format(time.RFC3339Nano), after.Format(time.RFC3339Nano)) + } + if edge.LastSeen.Before(before) || edge.LastSeen.After(after) { + t.Errorf("create time: %s, before time: %s, after time: %s", edge.LastSeen.Format(time.RFC3339Nano), before.Format(time.RFC3339Nano), after.Format(time.RFC3339Nano)) + } + if tags, err := c.cache.GetEdgeTags(edge, time.Time{}, "cache_create_edge"); err != nil || len(tags) != 1 { + t.Errorf("failed to create the cache tag:") + } + + time.Sleep(250 * time.Millisecond) + dbents, err := db2.FindEntityByContent(entity2.Asset, before) + assert.NoError(t, err) + + if num := len(dbents); num != 1 { + t.Errorf("failed to return the corrent number of entities: %d", num) + } + dbent := dbents[0] + + dbedges, err := db2.OutgoingEdges(dbent, before, "dns_record") + assert.NoError(t, err) + + if num := len(dbedges); num != 1 { + t.Errorf("failed to return the corrent number of edges: %d", num) + } + dbedge := dbedges[0] + + if !reflect.DeepEqual(edge.Relation, dbedge.Relation) { + t.Errorf("DeepEqual failed for the relations in the two edges") + } + if dbedge.CreatedAt.Before(before) || dbedge.CreatedAt.After(after) { + t.Errorf("create time: %s, before time: %s, after time: %s", dbedge.CreatedAt.Format(time.RFC3339Nano), before.Format(time.RFC3339Nano), after.Format(time.RFC3339Nano)) + } + if dbedge.LastSeen.Before(before) || dbedge.LastSeen.After(after) { + t.Errorf("create time: %s, before time: %s, after time: %s", dbedge.LastSeen.Format(time.RFC3339Nano), before.Format(time.RFC3339Nano), after.Format(time.RFC3339Nano)) + } +} + +func TestFindEdgeById(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2) + assert.NoError(t, err) + defer c.Close() + + now := time.Now() + ctime := now.Add(-8 * time.Hour) + + entity1, err := c.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: "owasp.org"}, + }) + assert.NoError(t, err) + + entity2, err := c.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: "www.owasp.org"}, + }) + assert.NoError(t, err) + + edge, err := c.CreateEdge(&types.Edge{ + CreatedAt: ctime, + LastSeen: ctime, + Relation: &relation.BasicDNSRelation{ + Name: "dns_record", + Header: relation.RRHeader{ + RRType: 5, + Class: 1, + TTL: 3600, + }, + }, + FromEntity: entity2, + ToEntity: entity1, + }) + assert.NoError(t, err) + + e, err := c.FindEdgeById(edge.ID) + assert.NoError(t, err) + + if !reflect.DeepEqual(edge.Relation, e.Relation) { + t.Errorf("DeepEqual failed for the relation in the two edges") + } +} + +func TestIncomingEdges(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2) + assert.NoError(t, err) + defer c.Close() + + now := time.Now() + ctime := now.Add(-8 * time.Hour) + before := ctime.Add(-2 * time.Second) + from, err := c.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: "caffix.com"}, + }) + assert.NoError(t, err) + time.Sleep(250 * time.Millisecond) + + dbfrom, err := c.db.FindEntityByContent(from.Asset, time.Time{}) + assert.NoError(t, err) + + set1 := stringset.New() + defer set1.Close() + // add some old stuff to the database + var entities1 []*types.Entity + for _, name := range []string{"owasp.org", "utica.edu", "sunypoly.edu"} { + set1.Insert(name) + e, err := c.db.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: name}, + }) + assert.NoError(t, err) + _, err = c.db.CreateEdge(&types.Edge{ + CreatedAt: ctime, + LastSeen: ctime, + Relation: relation.SimpleRelation{Name: "node"}, + FromEntity: dbfrom[0], + ToEntity: e, + }) + assert.NoError(t, err) + entities1 = append(entities1, e) + } + + set2 := stringset.New() + defer set2.Close() + // add some new stuff to the database + var entities2 []*types.Entity + for _, name := range []string{"www.owasp.org", "www.utica.edu", "www.sunypoly.edu"} { + set2.Insert(name) + e, err := c.CreateAsset(&domain.FQDN{Name: name}) + assert.NoError(t, err) + _, err = c.CreateEdge(&types.Edge{ + Relation: relation.SimpleRelation{Name: "node"}, + FromEntity: from, + ToEntity: e, + }) + assert.NoError(t, err) + entities2 = append(entities2, e) + } + after := time.Now().Add(time.Second) + + // some tests that shouldn't return anything + _, err = c.IncomingEdges(entities2[0], after) + assert.Error(t, err) + // there shouldn't be a tag for this entity, since it didn't require the database + _, err = c.cache.GetEntityTags(entities2[0], time.Time{}, "cache_incoming_edges") + assert.Error(t, err) + + for _, entity := range entities2 { + edges, err := c.IncomingEdges(entity, c.StartTime(), "node") + assert.NoError(t, err) + if len(edges) != 1 { + t.Errorf("%s had the incorrect number of incoming edges", entity.Asset.Key()) + } + set2.Remove(entity.Asset.Key()) + } + + // only entities from set2 should have been removed + if set1.Len() != 3 || set2.Len() != 0 { + t.Errorf("first request failed to produce the correct edges") + } + // there shouldn't be a tag for this entity, since it didn't require the database + _, err = c.cache.GetEntityTags(entities2[0], time.Time{}, "cache_incoming_edges") + assert.Error(t, err) + + var rentity *types.Entity + for _, entity := range entities1 { + e, err := c.FindEntityByContent(entity.Asset, time.Time{}) + assert.NoError(t, err) + rentity = e[0] + edges, err := c.IncomingEdges(rentity, before, "node") + assert.NoError(t, err) + if len(edges) != 1 { + t.Errorf("%s had the incorrect number of incoming edges", rentity.Asset.Key()) + } + set1.Remove(rentity.Asset.Key()) + } + + // all entities should now be been removed + if set1.Len() != 0 || set2.Len() != 0 { + t.Errorf("second request failed to produce the correct entities") + } + // there should be a tag for this entity + tags, err := c.cache.GetEntityTags(rentity, time.Time{}, "cache_incoming_edges") + assert.NoError(t, err) + if len(tags) != 1 { + t.Errorf("second request failed to produce the expected number of entity tags") + } + + tagtime, err := time.Parse(time.RFC3339Nano, tags[0].Property.Value()) + assert.NoError(t, err) + if tagtime.Before(before) || tagtime.After(after) { + t.Errorf("tag time: %s, before time: %s, after time: %s", tagtime.Format(time.RFC3339Nano), before.Format(time.RFC3339Nano), after.Format(time.RFC3339Nano)) + } +} + +func TestOutgoingEdges(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2) + assert.NoError(t, err) + defer c.Close() + + now := time.Now() + ctime := now.Add(-8 * time.Hour) + before := ctime.Add(-2 * time.Second) + from, err := c.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: "caffix.com"}, + }) + assert.NoError(t, err) + time.Sleep(250 * time.Millisecond) + + dbfrom, err := c.db.FindEntityByContent(from.Asset, time.Time{}) + assert.NoError(t, err) + + set1 := stringset.New() + defer set1.Close() + // add some old stuff to the database + for _, name := range []string{"owasp.org", "utica.edu", "sunypoly.edu"} { + set1.Insert(name) + e, err := c.db.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: name}, + }) + assert.NoError(t, err) + _, err = c.db.CreateEdge(&types.Edge{ + CreatedAt: ctime, + LastSeen: ctime, + Relation: relation.SimpleRelation{Name: "node"}, + FromEntity: dbfrom[0], + ToEntity: e, + }) + assert.NoError(t, err) + } + + set2 := stringset.New() + defer set2.Close() + // add some new stuff to the database + for _, name := range []string{"www.owasp.org", "www.utica.edu", "www.sunypoly.edu"} { + set2.Insert(name) + e, err := c.CreateAsset(&domain.FQDN{Name: name}) + assert.NoError(t, err) + _, err = c.CreateEdge(&types.Edge{ + Relation: relation.SimpleRelation{Name: "node"}, + FromEntity: from, + ToEntity: e, + }) + assert.NoError(t, err) + } + after := time.Now().Add(time.Second) + + // some tests that shouldn't return anything + _, err = c.OutgoingEdges(from, after) + assert.Error(t, err) + // there shouldn't be a tag for this entity, since it didn't require the database + _, err = c.cache.GetEntityTags(from, time.Time{}, "cache_outgoing_edges") + assert.Error(t, err) + + edges, err := c.OutgoingEdges(from, c.StartTime(), "node") + assert.NoError(t, err) + if len(edges) != 3 { + t.Errorf("incorrect number of outgoing edges") + } + + for _, edge := range edges { + e, err := c.FindEntityById(edge.ToEntity.ID) + assert.NoError(t, err) + set2.Remove(e.Asset.Key()) + } + + // only entities from set2 should have been removed + if set1.Len() != 3 || set2.Len() != 0 { + t.Errorf("first request failed to produce the correct edges") + } + // there shouldn't be a tag for this entity, since it didn't require the database + _, err = c.cache.GetEntityTags(from, time.Time{}, "cache_outgoing_edges") + assert.Error(t, err) + + edges, err = c.OutgoingEdges(from, before, "node") + assert.NoError(t, err) + if len(edges) != 6 { + t.Errorf("incorrect number of outgoing edges") + } + + for _, edge := range edges { + e, err := c.FindEntityById(edge.ToEntity.ID) + assert.NoError(t, err) + set1.Remove(e.Asset.Key()) + } + + // all entities should now be been removed + if set1.Len() != 0 || set2.Len() != 0 { + t.Errorf("second request failed to produce the correct entities") + } + // there should be a tag for this entity + tags, err := c.cache.GetEntityTags(from, time.Time{}, "cache_outgoing_edges") + assert.NoError(t, err) + if len(tags) != 1 { + t.Errorf("second request failed to produce the expected number of entity tags") + } + + tagtime, err := time.Parse(time.RFC3339Nano, tags[0].Property.Value()) + assert.NoError(t, err) + if tagtime.Before(before) || tagtime.After(after) { + t.Errorf("tag time: %s, before time: %s, after time: %s", tagtime.Format(time.RFC3339Nano), before.Format(time.RFC3339Nano), after.Format(time.RFC3339Nano)) + } +} + +func TestDeleteEdge(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2) + assert.NoError(t, err) + defer c.Close() + + now := time.Now() + ctime := now.Add(-8 * time.Hour) + before := ctime.Add(-2 * time.Second) + + entity1, err := c.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: "owasp.org"}, + }) + assert.NoError(t, err) + + entity2, err := c.CreateEntity(&types.Entity{ + CreatedAt: ctime, + LastSeen: ctime, + Asset: &domain.FQDN{Name: "www.owasp.org"}, + }) + assert.NoError(t, err) + + edge, err := c.CreateEdge(&types.Edge{ + CreatedAt: ctime, + LastSeen: ctime, + Relation: &relation.BasicDNSRelation{ + Name: "dns_record", + Header: relation.RRHeader{ + RRType: 5, + Class: 1, + TTL: 3600, + }, + }, + FromEntity: entity2, + ToEntity: entity1, + }) + assert.NoError(t, err) + + err = c.DeleteEdge(edge.ID) + assert.NoError(t, err) + + _, err = c.cache.FindEdgeById(edge.ID) + assert.Error(t, err) + + time.Sleep(250 * time.Millisecond) + dbent, err := c.db.FindEntityByContent(entity2.Asset, time.Time{}) + assert.NoError(t, err) + _, err = c.db.OutgoingEdges(dbent[0], before, edge.Relation.Label()) + assert.Error(t, err) +} diff --git a/cache/entity_test.go b/cache/entity_test.go index 124cab4..765b024 100644 --- a/cache/entity_test.go +++ b/cache/entity_test.go @@ -51,7 +51,7 @@ func TestCreateEntity(t *testing.T) { t.Errorf("failed to create the cache tag:") } - time.Sleep(time.Second) + time.Sleep(250 * time.Millisecond) dbents, err := db2.FindEntityByContent(entity.Asset, before) assert.NoError(t, err) @@ -100,7 +100,7 @@ func TestCreateAsset(t *testing.T) { t.Errorf("failed to create the cache tag:") } - time.Sleep(time.Second) + time.Sleep(250 * time.Millisecond) dbents, err := db2.FindEntityByContent(entity.Asset, now) assert.NoError(t, err) diff --git a/repository/sqlrepo/edge.go b/repository/sqlrepo/edge.go index 8ac6474..af76d81 100644 --- a/repository/sqlrepo/edge.go +++ b/repository/sqlrepo/edge.go @@ -162,6 +162,9 @@ func (sql *sqlRepository) IncomingEdges(entity *types.Entity, since time.Time, l results = edges } + if len(results) == 0 { + return nil, errors.New("zero edges found") + } return toEdges(results), nil } @@ -203,6 +206,9 @@ func (sql *sqlRepository) OutgoingEdges(entity *types.Entity, since time.Time, l results = edges } + if len(results) == 0 { + return nil, errors.New("zero edges found") + } return toEdges(results), nil }