diff --git a/cache/entity.go b/cache/entity.go index c82c96e..caf73cb 100644 --- a/cache/entity.go +++ b/cache/entity.go @@ -5,6 +5,7 @@ package cache import ( + "errors" "time" "github.com/owasp-amass/asset-db/types" @@ -69,17 +70,14 @@ func (c *Cache) FindEntityById(id string) (*types.Entity, error) { func (c *Cache) FindEntityByContent(asset oam.Asset, since time.Time) ([]*types.Entity, error) { c.Lock() entities, err := c.cache.FindEntityByContent(asset, since) - if err == nil && len(entities) == 1 { - if !since.IsZero() && !since.Before(c.start) { - c.Unlock() - return entities, err - } - if _, last, found := c.checkCacheEntityTag(entities[0], "cache_find_entity_by_content"); found && !since.Before(last) { - c.Unlock() - return entities, err - } + if err == nil && len(entities) >= 1 { + c.Unlock() + return entities, nil } c.Unlock() + if !since.IsZero() && !since.Before(c.start) { + return nil, err + } var dberr error var dbentities []*types.Entity @@ -107,14 +105,12 @@ func (c *Cache) FindEntityByContent(asset oam.Asset, since time.Time) ([]*types. Asset: entity.Asset, }); err == nil { results = append(results, e) - if tags, err := c.cache.GetEntityTags(entity, c.start, "cache_find_entity_by_content"); err == nil && len(tags) > 0 { - for _, tag := range tags { - _ = c.cache.DeleteEntityTag(tag.ID) - } - } - _ = c.createCacheEntityTag(entity, "cache_find_entity_by_content", since) } } + + if len(results) == 0 { + return nil, errors.New("zero entities found") + } return results, nil } diff --git a/cache/entity_test.go b/cache/entity_test.go index d01eddb..124cab4 100644 --- a/cache/entity_test.go +++ b/cache/entity_test.go @@ -10,7 +10,9 @@ import ( "testing" "time" + "github.com/caffix/stringset" "github.com/owasp-amass/asset-db/types" + oam "github.com/owasp-amass/open-asset-model" "github.com/owasp-amass/open-asset-model/domain" "github.com/stretchr/testify/assert" ) @@ -142,6 +144,238 @@ func TestFindEntityById(t *testing.T) { } } +func TestFindEntityByContent(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2) + assert.NoError(t, err) + defer c.Close() + + // add some really old stuff to the database + now := time.Now() + ctime1 := now.Add(-24 * time.Hour) + cbefore1 := ctime1.Add(-20 * time.Second) + fqdn1 := &domain.FQDN{Name: "owasp.org"} + entity1, err := c.db.CreateEntity(&types.Entity{ + CreatedAt: ctime1, + LastSeen: ctime1, + Asset: fqdn1, + }) + assert.NoError(t, err) + // add some not so old stuff to the database + ctime2 := now.Add(-8 * time.Hour) + cbefore2 := ctime2.Add(-20 * time.Second) + fqdn2 := &domain.FQDN{Name: "utica.edu"} + entity2, err := c.db.CreateEntity(&types.Entity{ + CreatedAt: ctime2, + LastSeen: ctime2, + Asset: fqdn2, + }) + assert.NoError(t, err) + // add new entities to the database + fqdn3 := &domain.FQDN{Name: "sunypoly.edu"} + entity3, err := c.CreateEntity(&types.Entity{ + CreatedAt: now, + LastSeen: now, + Asset: fqdn3, + }) + assert.NoError(t, err) + after := time.Now().Add(2 * time.Second) + + _, err = c.FindEntityByContent(fqdn3, after) + assert.Error(t, err) + + entities, err := c.FindEntityByContent(fqdn3, now) + assert.NoError(t, err) + if len(entities) != 1 { + t.Errorf("first request failed to produce the expected number of entities") + } + + e := entities[0] + if !reflect.DeepEqual(e.Asset, entity3.Asset) { + t.Errorf("DeepEqual failed for the assets in the two entities") + } + + _, err = c.FindEntityByContent(fqdn2, c.StartTime()) + assert.Error(t, err) + + entities, err = c.FindEntityByContent(fqdn2, cbefore2) + assert.NoError(t, err) + if len(entities) != 1 { + t.Errorf("second request failed to produce the expected number of entities") + } + + e = entities[0] + if !reflect.DeepEqual(e.Asset, entity2.Asset) { + t.Errorf("DeepEqual failed for the assets in the two entities") + } + + _, err = c.FindEntityByContent(fqdn1, cbefore2) + assert.Error(t, err) + + entities, err = c.FindEntityByContent(fqdn1, cbefore1) + assert.NoError(t, err) + if len(entities) != 1 { + t.Errorf("third request failed to produce the expected number of entities") + } + + e = entities[0] + if !reflect.DeepEqual(e.Asset, entity1.Asset) { + t.Errorf("DeepEqual failed for the assets in the two entities") + } +} + +func TestFindEntitiesByType(t *testing.T) { + db1, db2, dir, err := createTestRepositories() + assert.NoError(t, err) + defer func() { + db1.Close() + db2.Close() + os.RemoveAll(dir) + }() + + c, err := New(db1, db2) + assert.NoError(t, err) + defer c.Close() + + set1 := stringset.New() + defer set1.Close() + // add some really old stuff to the database + now := time.Now() + ctime1 := now.Add(-24 * time.Hour) + cbefore1 := ctime1.Add(-20 * time.Second) + cafter1 := ctime1.Add(20 * time.Second) + for _, name := range []string{"owasp.org", "utica.edu", "sunypoly.edu"} { + set1.Insert(name) + _, err := c.db.CreateEntity(&types.Entity{ + CreatedAt: ctime1, + LastSeen: ctime1, + Asset: &domain.FQDN{Name: name}, + }) + assert.NoError(t, err) + } + + set2 := stringset.New() + defer set2.Close() + // add some not so old stuff to the database + ctime2 := now.Add(-8 * time.Hour) + cbefore2 := ctime2.Add(-20 * time.Second) + cafter2 := ctime2.Add(20 * time.Second) + for _, name := range []string{"www.owasp.org", "www.utica.edu", "www.sunypoly.edu"} { + set2.Insert(name) + _, err := c.db.CreateEntity(&types.Entity{ + CreatedAt: ctime2, + LastSeen: ctime2, + Asset: &domain.FQDN{Name: name}, + }) + assert.NoError(t, err) + } + + set3 := stringset.New() + defer set3.Close() + // add new entities to the database + after := now.Add(20 * time.Second) + for _, name := range []string{"ns1.owasp.org", "ns1.utica.edu", "ns1.sunypoly.edu"} { + set3.Insert(name) + _, err := c.CreateAsset(&domain.FQDN{Name: name}) + assert.NoError(t, err) + } + + // no results should be produced with this since param + _, err = c.FindEntitiesByType(oam.FQDN, after) + assert.Error(t, err) + + entities, err := c.FindEntitiesByType(oam.FQDN, c.StartTime()) + assert.NoError(t, err) + if len(entities) != 3 { + t.Errorf("first request failed to produce the expected number of entities") + } + + for _, entity := range entities { + if fqdn, ok := entity.Asset.(*domain.FQDN); ok { + set1.Remove(fqdn.Name) + set2.Remove(fqdn.Name) + set3.Remove(fqdn.Name) + } + } + + // only entities from set3 should have been removed + if set1.Len() != 3 || set2.Len() != 3 || set3.Len() != 0 { + t.Errorf("first request failed to produce the correct entities") + } + // there shouldn't be a tag for this entity, since it didn't require the database + _, err = c.cache.GetEntityTags(entities[0], now, "cache_find_entities_by_type") + assert.Error(t, err) + + entities, err = c.FindEntitiesByType(oam.FQDN, ctime2) + assert.NoError(t, err) + if len(entities) != 6 { + t.Errorf("second request failed to produce the expected number of entities") + } + + for _, entity := range entities { + if fqdn, ok := entity.Asset.(*domain.FQDN); ok { + set1.Remove(fqdn.Name) + set2.Remove(fqdn.Name) + set3.Remove(fqdn.Name) + } + } + + // only entities from set3 should have been removed + if set1.Len() != 3 || set2.Len() != 0 || set3.Len() != 0 { + t.Errorf("second request failed to produce the correct entities") + } + // there should be a tag for this entity + tags, err := c.cache.GetEntityTags(entities[0], time.Time{}, "cache_find_entities_by_type") + assert.NoError(t, err) + if len(tags) != 1 { + t.Errorf("second request failed to produce the expected number of entity tags") + } + + tagtime, err := time.Parse(time.RFC3339Nano, tags[0].Property.Value()) + assert.NoError(t, err) + if tagtime.Before(cbefore2) || tagtime.After(cafter2) { + t.Errorf("tag time: %s, before time: %s, after time: %s", tagtime.Format(time.RFC3339Nano), cbefore2.Format(time.RFC3339Nano), cafter2.Format(time.RFC3339Nano)) + } + + entities, err = c.FindEntitiesByType(oam.FQDN, ctime1) + assert.NoError(t, err) + if len(entities) != 9 { + t.Errorf("third request failed to produce the expected number of entities") + } + + for _, entity := range entities { + if fqdn, ok := entity.Asset.(*domain.FQDN); ok { + set1.Remove(fqdn.Name) + set2.Remove(fqdn.Name) + set3.Remove(fqdn.Name) + } + } + + // only entities from set3 should have been removed + if set1.Len() != 0 || set2.Len() != 0 || set3.Len() != 0 { + t.Errorf("third request failed to produce the correct entities") + } + // there should now be a new tag for this entity + tags, err = c.cache.GetEntityTags(entities[0], time.Time{}, "cache_find_entities_by_type") + assert.NoError(t, err) + if len(tags) != 1 { + t.Errorf("third request failed to produce the expected number of entity tags") + } + + tagtime, err = time.Parse(time.RFC3339Nano, tags[0].Property.Value()) + assert.NoError(t, err) + if tagtime.Before(cbefore1) || tagtime.After(cafter1) { + t.Errorf("tag time: %s, before time: %s, after time: %s", tagtime.Format(time.RFC3339Nano), cbefore1.Format(time.RFC3339Nano), cafter1.Format(time.RFC3339Nano)) + } +} + func TestDeleteEntity(t *testing.T) { db1, db2, dir, err := createTestRepositories() assert.NoError(t, err) diff --git a/cache/tag.go b/cache/tag.go index b22de85..2d99a31 100644 --- a/cache/tag.go +++ b/cache/tag.go @@ -418,14 +418,14 @@ func (c *Cache) DeleteEdgeTag(id string) error { func (c *Cache) createCacheEntityTag(entity *types.Entity, name string, since time.Time) error { _, err := c.cache.CreateEntityProperty(entity, &property.SimpleProperty{ PropertyName: name, - PropertyValue: time.Now().Format("2006-01-02 15:04:05"), + PropertyValue: since.Format(time.RFC3339Nano), }) return err } func (c *Cache) checkCacheEntityTag(entity *types.Entity, name string) (*types.EntityTag, time.Time, bool) { if tags, err := c.cache.GetEntityTags(entity, time.Time{}, name); err == nil && len(tags) == 1 { - if t, err := time.Parse("2006-01-02 15:04:05", tags[0].Property.Value()); err == nil { + if t, err := time.Parse(time.RFC3339Nano, tags[0].Property.Value()); err == nil { return tags[0], t, true } } @@ -435,14 +435,14 @@ func (c *Cache) checkCacheEntityTag(entity *types.Entity, name string) (*types.E func (c *Cache) createCacheEdgeTag(edge *types.Edge, name string, since time.Time) error { _, err := c.cache.CreateEdgeProperty(edge, &property.SimpleProperty{ PropertyName: name, - PropertyValue: time.Now().Format("2006-01-02 15:04:05"), + PropertyValue: since.Format(time.RFC3339Nano), }) return err } func (c *Cache) checkCacheEdgeTag(edge *types.Edge, name string) (*types.EdgeTag, time.Time, bool) { if tags, err := c.cache.GetEdgeTags(edge, time.Time{}, name); err == nil && len(tags) == 1 { - if t, err := time.Parse("2006-01-02 15:04:05", tags[0].Property.Value()); err == nil { + if t, err := time.Parse(time.RFC3339Nano, tags[0].Property.Value()); err == nil { return tags[0], t, true } } diff --git a/go.mod b/go.mod index 064590a..5046868 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,13 @@ go 1.23.1 require ( github.com/caffix/queue v0.3.1 + github.com/caffix/stringset v0.2.0 github.com/glebarez/sqlite v1.11.0 github.com/owasp-amass/open-asset-model v0.12.0 github.com/rubenv/sql-migrate v1.7.0 github.com/stretchr/testify v1.9.0 gorm.io/datatypes v1.2.4 gorm.io/driver/postgres v1.5.9 - gorm.io/driver/sqlite v1.5.6 gorm.io/gorm v1.25.12 ) @@ -41,6 +41,7 @@ require ( golang.org/x/text v0.20.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gorm.io/driver/mysql v1.5.7 // indirect + gorm.io/driver/sqlite v1.5.6 // indirect modernc.org/libc v1.61.2 // indirect modernc.org/mathutil v1.6.0 // indirect modernc.org/memory v1.8.0 // indirect diff --git a/go.sum b/go.sum index 1cc1172..4a78be1 100644 --- a/go.sum +++ b/go.sum @@ -113,8 +113,6 @@ modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= modernc.org/gc/v2 v2.5.0 h1:bJ9ChznK1L1mUtAQtxi0wi5AtAs5jQuw4PrPHO5pb6M= modernc.org/gc/v2 v2.5.0/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= -modernc.org/libc v1.61.1 h1:F8JngdWfVzqfNpff2apn7JpBkjq1ss8Ue4KuUdLDM7Q= -modernc.org/libc v1.61.1/go.mod h1:4QGjNyX3h+rn7V5oHpJY2yH0QN6frt1X+5BkXzwLPCo= modernc.org/libc v1.61.2 h1:dkO4DlowfClcJYsvf/RiK6fUwvzCQTmB34bJLt0CAGQ= modernc.org/libc v1.61.2/go.mod h1:4QGjNyX3h+rn7V5oHpJY2yH0QN6frt1X+5BkXzwLPCo= modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= diff --git a/repository/sqlrepo/tag.go b/repository/sqlrepo/tag.go index 259f0a0..de086e0 100644 --- a/repository/sqlrepo/tag.go +++ b/repository/sqlrepo/tag.go @@ -5,6 +5,7 @@ package sqlrepo import ( + "errors" "strconv" "time" @@ -171,6 +172,9 @@ func (sql *sqlRepository) GetEntityTags(entity *types.Entity, since time.Time, n } } + if len(results) == 0 { + return nil, errors.New("zero tags found") + } return results, nil } @@ -354,6 +358,9 @@ func (sql *sqlRepository) GetEdgeTags(edge *types.Edge, since time.Time, names . } } + if len(results) == 0 { + return nil, errors.New("zero tags found") + } return results, nil }