Skip to content

Commit

Permalink
extended the Repository interface and continued cache development
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Nov 18, 2024
1 parent 46f5a8c commit fa1e376
Show file tree
Hide file tree
Showing 17 changed files with 371 additions and 246 deletions.
27 changes: 11 additions & 16 deletions assetdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,19 @@ type AssetDB struct {
// If the edge is provided, the entity is created and linked to the source entity using the specified edge.
// It returns the newly created entity and an error, if any.
func (as *AssetDB) Create(edge *types.Edge, asset oam.Asset) (*types.Entity, error) {
e, err := as.Repo.CreateEntity(asset)
e, err := as.Repo.CreateAsset(asset)
if err != nil || edge == nil || edge.FromEntity == nil || edge.Relation == nil {
return e, err
}

edge.ToEntity = e
_, err = as.Repo.Link(edge)
_, err = as.Repo.CreateEdge(edge)
if err != nil {
return nil, err
}
return e, nil
}

// UpdateEntityLastSeen updates the entity last seen field to the current time by its ID.
func (as *AssetDB) UpdateEntityLastSeen(id string) error {
return as.Repo.UpdateEntityLastSeen(id)
}

// DeleteEntity removes an entity in the database by its ID.
func (as *AssetDB) DeleteEntity(id string) error {
return as.Repo.DeleteEntity(id)
Expand Down Expand Up @@ -69,11 +64,11 @@ func (as *AssetDB) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*
return as.Repo.FindEntitiesByType(atype, since)
}

// Link creates an edge between two entities in the database.
// CreateEdge creates an edge between two entities in the database.
// The link is established by creating a new Edge in the database, linking the two entities.
// Returns the created edge as a types.Edge or an error if the link creation fails.
func (as *AssetDB) Link(edge *types.Edge) (*types.Edge, error) {
return as.Repo.Link(edge)
func (as *AssetDB) CreateEdge(edge *types.Edge) (*types.Edge, error) {
return as.Repo.CreateEdge(edge)
}

// IncomingEdges finds all edges pointing to the entity for the specified labels, if any.
Expand All @@ -90,12 +85,12 @@ func (as *AssetDB) OutgoingEdges(entity *types.Entity, since time.Time, labels .
return as.Repo.OutgoingEdges(entity, since, labels...)
}

// CreateEntityTag creates a new entity tag in the database.
// CreateEntityProperty creates a new entity tag in the database.
// It takes an oam.Property as input and persists it in the database.
// The entity tag is serialized to JSON and stored in the Content field of the EntityTag struct.
// Returns the created entity tag as a types.EntityTag or an error if the creation fails.
func (as *AssetDB) CreateEntityTag(entity *types.Entity, property oam.Property) (*types.EntityTag, error) {
return as.Repo.CreateEntityTag(entity, property)
func (as *AssetDB) CreateEntityProperty(entity *types.Entity, property oam.Property) (*types.EntityTag, error) {
return as.Repo.CreateEntityProperty(entity, property)
}

// GetEntityTags finds all tags for the entity with the specified names and last seen after the since parameter.
Expand All @@ -112,12 +107,12 @@ func (as *AssetDB) DeleteEntityTag(id string) error {
return as.Repo.DeleteEntityTag(id)
}

// CreateEdgeTag creates a new edge tag in the database.
// CreateEdgeProperty creates a new edge tag in the database.
// It takes an oam.Property as input and persists it in the database.
// The edge tag is serialized to JSON and stored in the Content field of the EdgeTag struct.
// Returns the created edge tag as a types.EdgeTag or an error if the creation fails.
func (as *AssetDB) CreateEdgeTag(edge *types.Edge, property oam.Property) (*types.EdgeTag, error) {
return as.Repo.CreateEdgeTag(edge, property)
func (as *AssetDB) CreateEdgeProperty(edge *types.Edge, property oam.Property) (*types.EdgeTag, error) {
return as.Repo.CreateEdgeProperty(edge, property)
}

// GetEdgeTags finds all tags for the edge with the specified names and last seen after the since parameter.
Expand Down
30 changes: 20 additions & 10 deletions assetdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestAssetDB(t *testing.T) {
}

if tc.expectedError == nil {
mockAssetDB.On("CreateEntity", tc.discovered).Return(tc.expected, tc.expectedError)
mockAssetDB.On("CreateAsset", tc.discovered).Return(tc.expected, tc.expectedError)
}

e := &types.Edge{
Expand All @@ -74,7 +74,7 @@ func TestAssetDB(t *testing.T) {
}

if tc.source != nil && tc.relation != "" {
mockAssetDB.On("Link", e).Return(&types.Edge{}, nil)
mockAssetDB.On("CreateEdge", e).Return(&types.Edge{}, nil)
}

result, err := adb.Create(e, tc.discovered)
Expand Down Expand Up @@ -384,14 +384,14 @@ func (m *mockAssetDB) GetDBType() string {
return args.String(0)
}

func (m *mockAssetDB) CreateEntity(asset oam.Asset) (*types.Entity, error) {
args := m.Called(asset)
func (m *mockAssetDB) CreateEntity(entity *types.Entity) (*types.Entity, error) {
args := m.Called(entity)
return args.Get(0).(*types.Entity), args.Error(1)
}

func (m *mockAssetDB) UpdateEntityLastSeen(id string) error {
args := m.Called(id)
return args.Error(0)
func (m *mockAssetDB) CreateAsset(asset oam.Asset) (*types.Entity, error) {
args := m.Called(asset)
return args.Get(0).(*types.Entity), args.Error(1)
}

func (m *mockAssetDB) FindEntityById(id string) (*types.Entity, error) {
Expand All @@ -414,7 +414,7 @@ func (m *mockAssetDB) DeleteEntity(id string) error {
return args.Error(0)
}

func (m *mockAssetDB) Link(edge *types.Edge) (*types.Edge, error) {
func (m *mockAssetDB) CreateEdge(edge *types.Edge) (*types.Edge, error) {
args := m.Called(edge)
return args.Get(0).(*types.Edge), args.Error(1)
}
Expand All @@ -439,7 +439,12 @@ func (m *mockAssetDB) DeleteEdge(id string) error {
return args.Error(0)
}

func (m *mockAssetDB) CreateEntityTag(entity *types.Entity, property oam.Property) (*types.EntityTag, error) {
func (m *mockAssetDB) CreateEntityTag(entity *types.Entity, tag *types.EntityTag) (*types.EntityTag, error) {
args := m.Called(entity, tag)
return args.Get(0).(*types.EntityTag), args.Error(1)
}

func (m *mockAssetDB) CreateEntityProperty(entity *types.Entity, property oam.Property) (*types.EntityTag, error) {
args := m.Called(entity, property)
return args.Get(0).(*types.EntityTag), args.Error(1)
}
Expand All @@ -459,7 +464,12 @@ func (m *mockAssetDB) DeleteEntityTag(id string) error {
return args.Error(0)
}

func (m *mockAssetDB) CreateEdgeTag(edge *types.Edge, property oam.Property) (*types.EdgeTag, error) {
func (m *mockAssetDB) CreateEdgeTag(edge *types.Edge, tag *types.EdgeTag) (*types.EdgeTag, error) {
args := m.Called(edge, tag)
return args.Get(0).(*types.EdgeTag), args.Error(1)
}

func (m *mockAssetDB) CreateEdgeProperty(edge *types.Edge, property oam.Property) (*types.EdgeTag, error) {
args := m.Called(edge, property)
return args.Get(0).(*types.EdgeTag), args.Error(1)
}
Expand Down
36 changes: 12 additions & 24 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,35 @@
package cache

import (
"errors"
"sync"
"time"

"github.com/caffix/queue"
assetdb "github.com/owasp-amass/asset-db"
"github.com/owasp-amass/asset-db/repository"
"github.com/owasp-amass/asset-db/repository/sqlrepo"
)

type Cache struct {
sync.Mutex
start time.Time
freq time.Duration
done chan struct{}
cdone chan struct{}
cache repository.Repository
db repository.Repository
queue queue.Queue
}

func New(database repository.Repository, done chan struct{}) (*Cache, error) {
if db := assetdb.New(sqlrepo.SQLiteMemory, ""); db != nil {
c := &Cache{
start: time.Now(),
freq: 10 * time.Minute,
done: done,
cdone: make(chan struct{}, 1),
cache: db.Repo,
db: database,
queue: queue.NewQueue(),
}

go c.processDBCallbacks()
return c, nil
func New(cache, database repository.Repository) (*Cache, error) {
c := &Cache{
start: time.Now(),
freq: 10 * time.Minute,
done: make(chan struct{}, 1),
cache: cache,
db: database,
queue: queue.NewQueue(),
}
return nil, errors.New("failed to create the cache repository")

go c.processDBCallbacks()
return c, nil
}

// StartTime returns the time that the cache was created.
Expand All @@ -60,7 +52,7 @@ func (c *Cache) Close() error {
}
}

close(c.cdone)
close(c.done)
for {
if c.queue.Empty() {
break
Expand All @@ -79,8 +71,6 @@ func (c *Cache) appendToDBQueue(callback func()) {
select {
case <-c.done:
return
case <-c.cdone:
return
default:
}
c.queue.Append(callback)
Expand All @@ -92,8 +82,6 @@ loop:
select {
case <-c.done:
break loop
case <-c.cdone:
break loop
case <-c.queue.Signal():
element, ok := c.queue.Next()

Expand Down
12 changes: 6 additions & 6 deletions cache/edge.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ import (
"github.com/owasp-amass/asset-db/types"
)

// Link implements the Repository interface.
func (c *Cache) Link(edge *types.Edge) (*types.Edge, error) {
// CreateEdge implements the Repository interface.
func (c *Cache) CreateEdge(edge *types.Edge) (*types.Edge, error) {
c.Lock()
defer c.Unlock()

e, err := c.cache.Link(edge)
e, err := c.cache.CreateEdge(edge)
if err != nil {
return nil, err
}

c.appendToDBQueue(func() {
_, _ = c.db.Link(edge)
_, _ = c.db.CreateEdge(edge)
})

return e, nil
Expand Down Expand Up @@ -80,14 +80,14 @@ func (c *Cache) DeleteEdge(id string) error {
return
}

edges, err := c.db.OutgoingEdges(s, time.Time{}, edge.Relation.Label())
edges, err := c.db.OutgoingEdges(s[0], time.Time{}, edge.Relation.Label())
if err != nil || len(edges) == 0 {
return
}

var target *types.Edge
for _, e := range edges {
if e.ID == o.ID && reflect.DeepEqual(e.Relation, o.Relation) {
if e.ID == o[0].ID && reflect.DeepEqual(e.Relation, edge.Relation) {
target = e
break
}
Expand Down
50 changes: 31 additions & 19 deletions cache/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,51 +12,55 @@ import (
)

// CreateEntity implements the Repository interface.
func (c *Cache) CreateEntity(asset oam.Asset) (*types.Entity, error) {
func (c *Cache) CreateEntity(input *types.Entity) (*types.Entity, error) {
c.Lock()
defer c.Unlock()

entity, err := c.cache.CreateEntity(asset)
entity, err := c.cache.CreateEntity(input)
if err != nil {
return nil, err
}

if tag, found := c.checkCacheEntityTag(entity, "cache_create_entity"); !found {
if last, err := time.Parse("2006-01-02 15:04:05", tag.Value()); err == nil && time.Now().Add(-1*c.freq).After(last) {
if last, err := time.Parse("2006-01-02 15:04:05", tag.Property.Value()); err == nil && time.Now().Add(-1*c.freq).After(last) {
_ = c.cache.DeleteEntityTag(tag.ID)
_ = c.createCacheEntityTag(entity, "cache_create_entity")

c.appendToDBQueue(func() {
_, _ = c.db.CreateEntity(asset)
_, _ = c.db.CreateEntity(&types.Entity{
CreatedAt: input.CreatedAt,
LastSeen: input.LastSeen,
Asset: input.Asset,
})
})
}
}

return entity, nil
}

// UpdateEntityLastSeen implements the Repository interface.
func (c *Cache) UpdateEntityLastSeen(id string) error {
// CreateAsset implements the Repository interface.
func (c *Cache) CreateAsset(asset oam.Asset) (*types.Entity, error) {
c.Lock()
defer c.Unlock()

err := c.cache.UpdateEntityLastSeen(id)
entity, err := c.cache.CreateAsset(asset)
if err != nil {
return err
return nil, err
}

entity, err := c.cache.FindEntityById(id)
if err != nil {
return nil
}
if tag, found := c.checkCacheEntityTag(entity, "cache_create_asset"); !found {
if last, err := time.Parse("2006-01-02 15:04:05", tag.Property.Value()); err == nil && time.Now().Add(-1*c.freq).After(last) {
_ = c.cache.DeleteEntityTag(tag.ID)
_ = c.createCacheEntityTag(entity, "cache_create_asset")

c.appendToDBQueue(func() {
if e, err := c.db.FindEntityByContent(entity.Asset, time.Time{}); err == nil && len(e) == 1 {
_ = c.db.UpdateEntityLastSeen(e[0].ID)
c.appendToDBQueue(func() {
_, _ = c.db.CreateAsset(asset)
})
}
})
}

return nil
return entity, nil
}

// FindEntityById implements the Repository interface.
Expand Down Expand Up @@ -103,7 +107,11 @@ func (c *Cache) FindEntityByContent(asset oam.Asset, since time.Time) ([]*types.

var results []*types.Entity
for _, entity := range dbentities {
if e, err := c.cache.CreateEntity(entity.Asset); err == nil {
if e, err := c.cache.CreateEntity(&types.Entity{
CreatedAt: entity.CreatedAt,
LastSeen: entity.LastSeen,
Asset: entity.Asset,
}); err == nil {
results = append(results, e)
if tags, err := c.cache.GetEntityTags(entity, c.start, "cache_find_entity_by_content"); err == nil && len(tags) > 0 {
for _, tag := range tags {
Expand Down Expand Up @@ -152,7 +160,11 @@ func (c *Cache) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*typ

var results []*types.Entity
for _, entity := range dbentities {
if e, err := c.cache.CreateEntity(entity.Asset); err == nil {
if e, err := c.cache.CreateEntity(&types.Entity{
CreatedAt: entity.CreatedAt,
LastSeen: entity.LastSeen,
Asset: entity.Asset,
}); err == nil {
results = append(results, e)
if tags, err := c.cache.GetEntityTags(entity, c.start, "cache_find_entities_by_type"); err == nil && len(tags) > 0 {
for _, tag := range tags {
Expand Down
Loading

0 comments on commit fa1e376

Please sign in to comment.