diff --git a/cache/cache.go b/cache/cache.go index 6e15488..a9ff3ba 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -7,7 +7,6 @@ package cache import ( "time" - "github.com/caffix/queue" "github.com/owasp-amass/asset-db/repository" "github.com/owasp-amass/asset-db/types" "github.com/owasp-amass/open-asset-model/property" @@ -16,23 +15,18 @@ import ( type Cache struct { start time.Time freq time.Duration - done chan struct{} cache repository.Repository db repository.Repository - queue queue.Queue } func New(cache, database repository.Repository, freq time.Duration) (*Cache, error) { c := &Cache{ start: time.Now(), freq: freq, - done: make(chan struct{}, 1), cache: cache, db: database, - queue: queue.NewQueue(), } - go c.processDBCallbacks() return c, nil } @@ -43,15 +37,6 @@ func (c *Cache) StartTime() time.Time { // Close implements the Repository interface. func (c *Cache) Close() error { - close(c.done) - - for { - if c.queue.Empty() { - break - } - time.Sleep(2 * time.Second) - } - return c.cache.Close() } @@ -60,45 +45,6 @@ func (c *Cache) GetDBType() string { return c.db.GetDBType() } -func (c *Cache) appendToDBQueue(callback func()) { - select { - case <-c.done: - return - default: - } - c.queue.Append(callback) -} - -func (c *Cache) processDBCallbacks() { - t := time.NewTicker(100 * time.Millisecond) - defer t.Stop() -loop: - for { - select { - case <-c.done: - break loop - case <-c.queue.Signal(): - if element, ok := c.queue.Next(); ok { - if callback, success := element.(func()); success { - callback() - } - } - case <-t.C: - if element, ok := c.queue.Next(); ok { - if callback, success := element.(func()); success { - callback() - } - } - } - } - // execute the remaining callbacks in the queue - c.queue.Process(func(data interface{}) { - if callback, success := data.(func()); success { - callback() - } - }) -} - func (c *Cache) createCacheEntityTag(entity *types.Entity, name string, since time.Time) error { _, err := c.cache.CreateEntityProperty(entity, &property.SimpleProperty{ PropertyName: name, diff --git a/cache/edge.go b/cache/edge.go index 16f57b8..5bd2be6 100644 --- a/cache/edge.go +++ b/cache/edge.go @@ -34,28 +34,26 @@ func (c *Cache) CreateEdge(edge *types.Edge) (*types.Edge, error) { } _ = c.createCacheEdgeTag(e, "cache_create_edge", time.Now()) - c.appendToDBQueue(func() { - s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) - if err != nil || len(s) != 1 { - return - } + s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) + if err != nil || len(s) != 1 { + return nil, err + } - o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) - if err != nil || len(o) != 1 { - return - } + o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) + if err != nil || len(o) != 1 { + return nil, err + } - _, _ = c.db.CreateEdge(&types.Edge{ - CreatedAt: edge.CreatedAt, - LastSeen: edge.LastSeen, - Relation: e.Relation, - FromEntity: s[0], - ToEntity: o[0], - }) + _, _ = c.db.CreateEdge(&types.Edge{ + CreatedAt: edge.CreatedAt, + LastSeen: edge.LastSeen, + Relation: e.Relation, + FromEntity: s[0], + ToEntity: o[0], }) } - return e, nil + return e, err } // FindEdgeById implements the Repository interface. @@ -81,22 +79,15 @@ func (c *Cache) IncomingEdges(entity *types.Entity, since time.Time, labels ...s var dberr error var dbedges []*types.Edge - done := make(chan struct{}, 1) - c.appendToDBQueue(func() { - defer func() { done <- struct{}{} }() + if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { + dbedges, dberr = c.db.IncomingEdges(e[0], since) - if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - dbedges, dberr = c.db.IncomingEdges(e[0], since) - - for i, edge := range dbedges { - if e, err := c.db.FindEntityById(edge.ToEntity.ID); err == nil && e != nil { - dbedges[i].ToEntity = e - } + for i, edge := range dbedges { + if e, err := c.db.FindEntityById(edge.ToEntity.ID); err == nil && e != nil { + dbedges[i].ToEntity = e } } - }) - <-done - close(done) + } if dberr == nil && len(dbedges) > 0 { for _, edge := range dbedges { @@ -140,22 +131,15 @@ func (c *Cache) OutgoingEdges(entity *types.Entity, since time.Time, labels ...s var dberr error var dbedges []*types.Edge - done := make(chan struct{}, 1) - c.appendToDBQueue(func() { - defer func() { done <- struct{}{} }() - - if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - dbedges, dberr = c.db.OutgoingEdges(e[0], since) + if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { + dbedges, dberr = c.db.OutgoingEdges(e[0], since) - for i, edge := range dbedges { - if e, err := c.db.FindEntityById(edge.ToEntity.ID); err == nil && e != nil { - dbedges[i].ToEntity = e - } + for i, edge := range dbedges { + if e, err := c.db.FindEntityById(edge.ToEntity.ID); err == nil && e != nil { + dbedges[i].ToEntity = e } } - }) - <-done - close(done) + } if dberr == nil && len(dbedges) > 0 { for _, edge := range dbedges { @@ -202,33 +186,31 @@ func (c *Cache) DeleteEdge(id string) error { return err } - c.appendToDBQueue(func() { - s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) - if err != nil || len(s) != 1 { - return - } + s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) + if err != nil || len(s) != 1 { + return err + } - o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) - if err != nil || len(o) != 1 { - return - } + o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) + if err != nil || len(o) != 1 { + return err + } - edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label()) - if err != nil || len(edges) == 0 { - return - } + edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label()) + if err != nil || len(edges) == 0 { + return err + } - var target *types.Edge - for _, e := range edges { - if e.ToEntity.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge.Relation) { - target = e - break - } + var target *types.Edge + for _, e := range edges { + if e.ToEntity.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge.Relation) { + target = e + break } - if target != nil { - _ = c.db.DeleteEdge(target.ID) - } - }) + } + if target != nil { + err = c.db.DeleteEdge(target.ID) + } - return nil + return err } diff --git a/cache/edge_tag.go b/cache/edge_tag.go index e06e0a2..6e6b14c 100644 --- a/cache/edge_tag.go +++ b/cache/edge_tag.go @@ -5,7 +5,6 @@ package cache import ( - "errors" "reflect" "time" @@ -44,35 +43,33 @@ func (c *Cache) CreateEdgeTag(edge *types.Edge, input *types.EdgeTag) (*types.Ed return nil, err } - c.appendToDBQueue(func() { - s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) - if err != nil || len(s) != 1 { - return - } + s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) + if err != nil || len(s) != 1 { + return nil, err + } - o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) - if err != nil || len(o) != 1 { - return - } + o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) + if err != nil || len(o) != 1 { + return nil, err + } - edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label()) - if err != nil || len(edges) == 0 { - return - } + edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label()) + if err != nil || len(edges) == 0 { + return nil, err + } - var target *types.Edge - for _, e := range edges { - if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge2.Relation) { - target = e - break - } - } - if target != nil { - _, _ = c.db.CreateEdgeProperty(target, input.Property) + var target *types.Edge + for _, e := range edges { + if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge2.Relation) { + target = e + break } - }) + } + if target != nil { + _, err = c.db.CreateEdgeProperty(target, input.Property) + } - return tag, nil + return tag, err } // CreateEdgeProperty implements the Repository interface. @@ -106,35 +103,33 @@ func (c *Cache) CreateEdgeProperty(edge *types.Edge, property oam.Property) (*ty return nil, err } - c.appendToDBQueue(func() { - s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) - if err != nil || len(s) != 1 { - return - } + s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) + if err != nil || len(s) != 1 { + return nil, err + } - o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) - if err != nil || len(o) != 1 { - return - } + o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) + if err != nil || len(o) != 1 { + return nil, err + } - edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label()) - if err != nil || len(edges) == 0 { - return - } + edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label()) + if err != nil || len(edges) == 0 { + return nil, err + } - var target *types.Edge - for _, e := range edges { - if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge2.Relation) { - target = e - break - } - } - if target != nil { - _, _ = c.db.CreateEdgeProperty(target, property) + var target *types.Edge + for _, e := range edges { + if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge2.Relation) { + target = e + break } - }) + } + if target != nil { + _, err = c.db.CreateEdgeProperty(target, property) + } - return tag, nil + return tag, err } // FindEdgeTagById implements the Repository interface. @@ -144,24 +139,11 @@ func (c *Cache) FindEdgeTagById(id string) (*types.EdgeTag, error) { // FindEdgeTagsByContent implements the Repository interface. func (c *Cache) FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*types.EdgeTag, error) { - tags, err := c.cache.FindEdgeTagsByContent(prop, since) - if err == nil && len(tags) > 0 { - return tags, nil - } - - if !since.IsZero() && !since.Before(c.start) { - return nil, err - } - - var dberr error - var dbedges []*types.Edge - var dbtags []*types.EdgeTag - var froms, tos []*types.Entity - done := make(chan struct{}, 1) - c.appendToDBQueue(func() { - defer func() { done <- struct{}{} }() + if since.IsZero() || since.Before(c.start) { + var dbedges []*types.Edge + var froms, tos []*types.Entity - dbtags, dberr = c.db.FindEdgeTagsByContent(prop, since) + dbtags, dberr := c.db.FindEdgeTagsByContent(prop, since) if dberr == nil && len(dbtags) > 0 { for _, tag := range dbtags { if edge, err := c.db.FindEdgeById(tag.Edge.ID); err == nil && edge != nil { @@ -179,58 +161,48 @@ func (c *Cache) FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*ty } } } - }) - <-done - close(done) - - if dberr != nil { - return tags, err - } - - var results []*types.EdgeTag - for i, tag := range dbtags { - from, err := c.cache.CreateEntity(&types.Entity{ - CreatedAt: froms[i].CreatedAt, - LastSeen: froms[i].LastSeen, - Asset: froms[i].Asset, - }) - if err != nil || from == nil { - continue - } - to, err := c.cache.CreateEntity(&types.Entity{ - CreatedAt: tos[i].CreatedAt, - LastSeen: tos[i].LastSeen, - Asset: tos[i].Asset, - }) - if err != nil || to == nil { - continue - } + if dberr == nil { + for i, tag := range dbtags { + from, err := c.cache.CreateEntity(&types.Entity{ + CreatedAt: froms[i].CreatedAt, + LastSeen: froms[i].LastSeen, + Asset: froms[i].Asset, + }) + if err != nil || from == nil { + continue + } - edge, err := c.cache.CreateEdge(&types.Edge{ - CreatedAt: dbedges[i].CreatedAt, - LastSeen: dbedges[i].LastSeen, - Relation: dbedges[i].Relation, - FromEntity: from, - ToEntity: to, - }) - if err != nil || edge == nil { - continue - } + to, err := c.cache.CreateEntity(&types.Entity{ + CreatedAt: tos[i].CreatedAt, + LastSeen: tos[i].LastSeen, + Asset: tos[i].Asset, + }) + if err != nil || to == nil { + continue + } + + edge, err := c.cache.CreateEdge(&types.Edge{ + CreatedAt: dbedges[i].CreatedAt, + LastSeen: dbedges[i].LastSeen, + Relation: dbedges[i].Relation, + FromEntity: from, + ToEntity: to, + }) + if err != nil || edge == nil { + continue + } - if e, err := c.cache.CreateEdgeTag(edge, &types.EdgeTag{ - CreatedAt: tag.CreatedAt, - LastSeen: tag.LastSeen, - Property: tag.Property, - }); err == nil { - results = append(results, e) + _, _ = c.cache.CreateEdgeTag(edge, &types.EdgeTag{ + CreatedAt: tag.CreatedAt, + LastSeen: tag.LastSeen, + Property: tag.Property, + }) + } } } - if len(results) == 0 { - return nil, errors.New("zero edge tags found") - } - return results, nil + return c.cache.FindEdgeTagsByContent(prop, since) } // GetEdgeTags implements the Repository interface. @@ -260,38 +232,32 @@ func (c *Cache) GetEdgeTags(edge *types.Edge, since time.Time, names ...string) var dberr error var dbtags []*types.EdgeTag - done := make(chan struct{}, 1) - c.appendToDBQueue(func() { - defer func() { done <- struct{}{} }() - s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) - if err != nil || len(s) != 1 { - return - } + s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) + if err != nil || len(s) != 1 { + return nil, err + } - o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) - if err != nil || len(o) != 1 { - return - } + o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) + if err != nil || len(o) != 1 { + return nil, err + } - edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label()) - if err != nil || len(edges) == 0 { - return - } + edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label()) + if err != nil || len(edges) == 0 { + return nil, err + } - var target *types.Edge - for _, e := range edges { - if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge.Relation) { - target = e - break - } - } - if target != nil { - dbtags, dberr = c.db.GetEdgeTags(target, since) + var target *types.Edge + for _, e := range edges { + if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge.Relation) { + target = e + break } - }) - <-done - close(done) + } + if target != nil { + dbtags, dberr = c.db.GetEdgeTags(target, since) + } if dberr == nil && len(dbtags) > 0 { for _, tag := range dbtags { @@ -333,41 +299,39 @@ func (c *Cache) DeleteEdgeTag(id string) error { return err } - c.appendToDBQueue(func() { - s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) - if err != nil || len(s) != 1 { - return - } + s, err := c.db.FindEntityByContent(sub.Asset, time.Time{}) + if err != nil || len(s) != 1 { + return err + } - o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) - if err != nil || len(o) != 1 { - return - } + o, err := c.db.FindEntityByContent(obj.Asset, time.Time{}) + if err != nil || len(o) != 1 { + return err + } - edges, err := c.db.OutgoingEdges(s[0], time.Time{}, tag.Edge.Relation.Label()) - if err != nil || len(edges) == 0 { - return - } + edges, err := c.db.OutgoingEdges(s[0], time.Time{}, tag.Edge.Relation.Label()) + if err != nil || len(edges) == 0 { + return err + } - var target *types.Edge - for _, e := range edges { - if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge2.Relation) { - target = e - break - } - } - if target == nil { - return + var target *types.Edge + for _, e := range edges { + if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge2.Relation) { + target = e + break } + } + if target == nil { + return err + } - if tags, err := c.db.GetEdgeTags(target, time.Time{}, tag.Property.Name()); err == nil && len(tags) > 0 { - for _, t := range tags { - if tag.Property.Value() == t.Property.Value() { - _ = c.db.DeleteEdgeTag(t.ID) - } + if tags, err := c.db.GetEdgeTags(target, time.Time{}, tag.Property.Name()); err == nil && len(tags) > 0 { + for _, t := range tags { + if tag.Property.Value() == t.Property.Value() { + _ = c.db.DeleteEdgeTag(t.ID) } } - }) + } return nil } diff --git a/cache/edge_tag_test.go b/cache/edge_tag_test.go index 8c609f9..6e4b101 100644 --- a/cache/edge_tag_test.go +++ b/cache/edge_tag_test.go @@ -177,6 +177,81 @@ func TestFindEdgeTagById(t *testing.T) { } } +func TestFindEdgeTagsByContent(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2, time.Minute) + assert.NoError(t, err) + defer c.Close() + + // add some really old stuff to the database + now := time.Now() + ctime1 := now.Add(-24 * time.Hour) + cbefore1 := ctime1.Add(-20 * time.Second) + edge, err := createTestEdge(c, ctime1) + assert.NoError(t, err) + prop1 := &property.SimpleProperty{ + PropertyName: "test1", + PropertyValue: "foobar", + } + _, err = c.CreateEdgeTag(edge, &types.EdgeTag{ + CreatedAt: ctime1, + LastSeen: ctime1, + Property: prop1, + Edge: edge, + }) + assert.NoError(t, err) + // add some not so old stuff to the database + ctime2 := now.Add(-8 * time.Hour) + cbefore2 := ctime2.Add(-20 * time.Second) + prop2 := &property.SimpleProperty{ + PropertyName: "test2", + PropertyValue: "foobar", + } + _, err = c.CreateEdgeTag(edge, &types.EdgeTag{ + CreatedAt: ctime2, + LastSeen: ctime2, + Property: prop2, + Edge: edge, + }) + assert.NoError(t, err) + // add new entities to the database + prop3 := &property.SimpleProperty{ + PropertyName: "test3", + PropertyValue: "foobar", + } + _, err = c.CreateEdgeProperty(edge, prop3) + assert.NoError(t, err) + after := time.Now().Add(time.Second) + + _, err = c.FindEdgeTagsByContent(prop3, after) + assert.Error(t, err) + + tags, err := c.FindEdgeTagsByContent(prop3, c.StartTime()) + assert.NoError(t, err) + if len(tags) != 1 { + t.Errorf("first request failed to produce the expected number of tags") + } + + tags, err = c.FindEdgeTagsByContent(prop2, cbefore2) + assert.NoError(t, err) + if len(tags) != 1 { + t.Errorf("second request failed to produce the expected number of tags") + } + + tags, err = c.FindEdgeTagsByContent(prop1, cbefore1) + assert.NoError(t, err) + if len(tags) != 1 { + t.Errorf("third request failed to produce the expected number of tags") + } +} + func TestGetEdgeTags(t *testing.T) { db1, db2, dir, err := createTestRepositories() assert.NoError(t, err) diff --git a/cache/entity.go b/cache/entity.go index f4cd3d4..1a38543 100644 --- a/cache/entity.go +++ b/cache/entity.go @@ -25,16 +25,14 @@ func (c *Cache) CreateEntity(input *types.Entity) (*types.Entity, error) { } _ = c.createCacheEntityTag(entity, "cache_create_entity", time.Now()) - c.appendToDBQueue(func() { - _, _ = c.db.CreateEntity(&types.Entity{ - CreatedAt: input.CreatedAt, - LastSeen: input.LastSeen, - Asset: input.Asset, - }) + _, err = c.db.CreateEntity(&types.Entity{ + CreatedAt: input.CreatedAt, + LastSeen: input.LastSeen, + Asset: input.Asset, }) } - return entity, nil + return entity, err } // CreateAsset implements the Repository interface. @@ -50,12 +48,10 @@ func (c *Cache) CreateAsset(asset oam.Asset) (*types.Entity, error) { } _ = c.createCacheEntityTag(entity, "cache_create_asset", time.Now()) - c.appendToDBQueue(func() { - _, _ = c.db.CreateAsset(asset) - }) + _, err = c.db.CreateAsset(asset) } - return entity, nil + return entity, err } // FindEntityById implements the Repository interface. @@ -74,17 +70,7 @@ func (c *Cache) FindEntityByContent(asset oam.Asset, since time.Time) ([]*types. return nil, err } - var dberr error - var dbentities []*types.Entity - done := make(chan struct{}, 1) - c.appendToDBQueue(func() { - defer func() { done <- struct{}{} }() - - dbentities, dberr = c.db.FindEntityByContent(asset, since) - }) - <-done - close(done) - + dbentities, dberr := c.db.FindEntityByContent(asset, since) if dberr != nil { return entities, err } @@ -118,17 +104,7 @@ func (c *Cache) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*typ } } - var dberr error - var dbentities []*types.Entity - done := make(chan struct{}, 1) - c.appendToDBQueue(func() { - defer func() { done <- struct{}{} }() - - dbentities, dberr = c.db.FindEntitiesByType(atype, since) - }) - <-done - close(done) - + dbentities, dberr := c.db.FindEntitiesByType(atype, since) if dberr != nil { return entities, err } @@ -164,11 +140,9 @@ func (c *Cache) DeleteEntity(id string) error { return err } - c.appendToDBQueue(func() { - if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - _ = c.db.DeleteEntity(e[0].ID) - } - }) + if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { + _ = c.db.DeleteEntity(e[0].ID) + } return nil } diff --git a/cache/entity_tag.go b/cache/entity_tag.go index 130b54f..daac8be 100644 --- a/cache/entity_tag.go +++ b/cache/entity_tag.go @@ -5,7 +5,6 @@ package cache import ( - "errors" "time" "github.com/owasp-amass/asset-db/types" @@ -28,15 +27,13 @@ func (c *Cache) CreateEntityTag(entity *types.Entity, input *types.EntityTag) (* return nil, err } - c.appendToDBQueue(func() { - if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - _, _ = c.db.CreateEntityTag(e[0], &types.EntityTag{ - CreatedAt: input.CreatedAt, - LastSeen: input.LastSeen, - Property: input.Property, - }) - } - }) + if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { + _, _ = c.db.CreateEntityTag(e[0], &types.EntityTag{ + CreatedAt: input.CreatedAt, + LastSeen: input.LastSeen, + Property: input.Property, + }) + } return tag, nil } @@ -57,11 +54,9 @@ func (c *Cache) CreateEntityProperty(entity *types.Entity, property oam.Property return nil, err } - c.appendToDBQueue(func() { - if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - _, _ = c.db.CreateEntityProperty(e[0], property) - } - }) + if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { + _, _ = c.db.CreateEntityProperty(e[0], property) + } return tag, nil } @@ -73,23 +68,10 @@ func (c *Cache) FindEntityTagById(id string) (*types.EntityTag, error) { // FindEntityTagsByContent implements the Repository interface. func (c *Cache) FindEntityTagsByContent(prop oam.Property, since time.Time) ([]*types.EntityTag, error) { - tags, err := c.cache.FindEntityTagsByContent(prop, since) - if err == nil && len(tags) > 0 { - return tags, nil - } - - if !since.IsZero() && !since.Before(c.start) { - return nil, err - } - - var dberr error - var dbtags []*types.EntityTag - var dbentities []*types.Entity - done := make(chan struct{}, 1) - c.appendToDBQueue(func() { - defer func() { done <- struct{}{} }() + if since.IsZero() || since.Before(c.start) { + var dbentities []*types.Entity - dbtags, dberr = c.db.FindEntityTagsByContent(prop, since) + dbtags, dberr := c.db.FindEntityTagsByContent(prop, since) if dberr == nil && len(dbtags) > 0 { for _, tag := range dbtags { if entity, err := c.db.FindEntityById(tag.Entity.ID); err == nil && entity != nil { @@ -97,34 +79,26 @@ func (c *Cache) FindEntityTagsByContent(prop oam.Property, since time.Time) ([]* } } } - }) - <-done - close(done) - if dberr != nil { - return tags, err - } - - var results []*types.EntityTag - for i, tag := range dbtags { - entity, err := c.cache.CreateEntity(dbentities[i]) - if err != nil || entity == nil { - continue - } - - if e, err := c.cache.CreateEntityTag(entity, &types.EntityTag{ - CreatedAt: tag.CreatedAt, - LastSeen: tag.LastSeen, - Property: tag.Property, - }); err == nil { - results = append(results, e) + if dberr == nil { + for i, tag := range dbtags { + if entity, err := c.cache.CreateEntity(&types.Entity{ + CreatedAt: dbentities[i].CreatedAt, + LastSeen: dbentities[i].LastSeen, + Asset: dbentities[i].Asset, + }); err == nil && entity != nil { + _, _ = c.cache.CreateEntityTag(entity, &types.EntityTag{ + CreatedAt: tag.CreatedAt, + LastSeen: tag.LastSeen, + Property: tag.Property, + Entity: entity, + }) + } + } } } - if len(results) == 0 { - return nil, errors.New("zero entity tags found") - } - return results, nil + return c.cache.FindEntityTagsByContent(prop, since) } // GetEntityTags implements the Repository interface. @@ -145,16 +119,9 @@ func (c *Cache) GetEntityTags(entity *types.Entity, since time.Time, names ...st var dberr error var dbtags []*types.EntityTag - done := make(chan struct{}, 1) - c.appendToDBQueue(func() { - defer func() { done <- struct{}{} }() - - if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - dbtags, dberr = c.db.GetEntityTags(e[0], since) - } - }) - <-done - close(done) + if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { + dbtags, dberr = c.db.GetEntityTags(e[0], since) + } if dberr == nil && len(dbtags) > 0 { for _, tag := range dbtags { @@ -186,17 +153,15 @@ func (c *Cache) DeleteEntityTag(id string) error { return err } - c.appendToDBQueue(func() { - if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { - if tags, err := c.db.GetEntityTags(e[0], time.Time{}, tag.Property.Name()); err == nil && len(tags) > 0 { - for _, t := range tags { - if t.Property.Value() == tag.Property.Value() { - _ = c.db.DeleteEntityTag(t.ID) - } + if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 { + if tags, err := c.db.GetEntityTags(e[0], time.Time{}, tag.Property.Name()); err == nil && len(tags) > 0 { + for _, t := range tags { + if t.Property.Value() == tag.Property.Value() { + _ = c.db.DeleteEntityTag(t.ID) } } } - }) + } return nil } diff --git a/cache/entity_tag_test.go b/cache/entity_tag_test.go index 526591b..b80f5b3 100644 --- a/cache/entity_tag_test.go +++ b/cache/entity_tag_test.go @@ -168,6 +168,85 @@ func TestFindEntityTagById(t *testing.T) { } } +func TestFindEntityTagsByContent(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2, time.Minute) + assert.NoError(t, err) + defer c.Close() + + // add some really old stuff to the database + now := time.Now() + prop := &property.SimpleProperty{ + PropertyName: "test", + PropertyValue: "foobar", + } + ctime1 := now.Add(-24 * time.Hour) + cbefore1 := ctime1.Add(-20 * time.Second) + fqdn1 := &domain.FQDN{Name: "owasp.org"} + entity1, err := c.db.CreateEntity(&types.Entity{ + CreatedAt: ctime1, + LastSeen: ctime1, + Asset: fqdn1, + }) + assert.NoError(t, err) + _, err = c.db.CreateEntityTag(entity1, &types.EntityTag{ + CreatedAt: ctime1, + LastSeen: ctime1, + Property: prop, + }) + assert.NoError(t, err) + // add some not so old stuff to the database + ctime2 := now.Add(-8 * time.Hour) + cbefore2 := ctime2.Add(-20 * time.Second) + fqdn2 := &domain.FQDN{Name: "utica.edu"} + entity2, err := c.db.CreateEntity(&types.Entity{ + CreatedAt: ctime2, + LastSeen: ctime2, + Asset: fqdn2, + }) + assert.NoError(t, err) + _, err = c.db.CreateEntityTag(entity2, &types.EntityTag{ + CreatedAt: ctime2, + LastSeen: ctime2, + Property: prop, + }) + assert.NoError(t, err) + // add new entities to the database + entity3, err := c.CreateAsset(&domain.FQDN{Name: "sunypoly.edu"}) + assert.NoError(t, err) + _, err = c.CreateEntityProperty(entity3, prop) + assert.NoError(t, err) + after := time.Now().Add(time.Second) + + _, err = c.FindEntityTagsByContent(prop, after) + assert.Error(t, err) + + tags, err := c.FindEntityTagsByContent(prop, c.StartTime()) + assert.NoError(t, err) + if len(tags) != 1 { + t.Errorf("first request failed to produce the expected number of tags") + } + + tags, err = c.FindEntityTagsByContent(prop, cbefore2) + assert.NoError(t, err) + if len(tags) != 2 { + t.Errorf("second request failed to produce the expected number of tags") + } + + tags, err = c.FindEntityTagsByContent(prop, cbefore1) + assert.NoError(t, err) + if len(tags) != 3 { + t.Errorf("third request failed to produce the expected number of tags") + } +} + func TestGetEntityTags(t *testing.T) { db1, db2, dir, err := createTestRepositories() assert.NoError(t, err) diff --git a/repository/sqlrepo/db.go b/repository/sqlrepo/db.go index 8aa9cf5..c47756b 100644 --- a/repository/sqlrepo/db.go +++ b/repository/sqlrepo/db.go @@ -6,6 +6,7 @@ package sqlrepo import ( "errors" + "time" "github.com/glebarez/sqlite" "gorm.io/driver/postgres" @@ -53,12 +54,40 @@ func newDatabase(dbtype, dsn string) (*gorm.DB, error) { // postgresDatabase creates a new PostgreSQL database connection using the provided data source name (dsn). func postgresDatabase(dsn string) (*gorm.DB, error) { - return gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) + if err != nil { + return nil, err + } + + sqlDB, err := db.DB() + if err != nil { + return nil, err + } + + sqlDB.SetMaxIdleConns(10) + sqlDB.SetMaxOpenConns(20) + sqlDB.SetConnMaxLifetime(time.Hour) + sqlDB.SetConnMaxIdleTime(10 * time.Minute) + return db, nil } // sqliteDatabase creates a new SQLite database connection using the provided data source name (dsn). func sqliteDatabase(dsn string) (*gorm.DB, error) { - return gorm.Open(sqlite.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) + if err != nil { + return nil, err + } + + sqlDB, err := db.DB() + if err != nil { + return nil, err + } + + sqlDB.SetMaxIdleConns(5) + sqlDB.SetMaxOpenConns(3) + sqlDB.SetConnMaxLifetime(time.Hour) + sqlDB.SetConnMaxIdleTime(10 * time.Minute) + return db, nil } // Close implements the Repository interface. diff --git a/repository/sqlrepo/tag.go b/repository/sqlrepo/tag.go index 6616958..1d1dcc5 100644 --- a/repository/sqlrepo/tag.go +++ b/repository/sqlrepo/tag.go @@ -166,6 +166,7 @@ func (sql *sqlRepository) FindEntityTagsByContent(prop oam.Property, since time. CreatedAt: t.CreatedAt.In(time.UTC).Local(), LastSeen: t.UpdatedAt.In(time.UTC).Local(), Property: propData, + Entity: &types.Entity{ID: strconv.FormatUint(t.EntityID, 10)}, }) } } @@ -407,6 +408,7 @@ func (sql *sqlRepository) FindEdgeTagsByContent(prop oam.Property, since time.Ti CreatedAt: t.CreatedAt.In(time.UTC).Local(), LastSeen: t.UpdatedAt.In(time.UTC).Local(), Property: propData, + Edge: &types.Edge{ID: strconv.FormatUint(t.EdgeID, 10)}, }) } }