Skip to content

Commit

Permalink
implemented most of the Repository interface Entity methods
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Dec 17, 2024
1 parent 6e5f3ea commit e3058c1
Show file tree
Hide file tree
Showing 5 changed files with 615 additions and 171 deletions.
7 changes: 1 addition & 6 deletions repository/neo4j/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"testing"

neomigrations "github.com/owasp-amass/asset-db/migrations/neo4j"
"github.com/stretchr/testify/assert"
)

var store *neoRepository
Expand All @@ -25,6 +24,7 @@ func TestMain(m *testing.M) {
fmt.Println(err)
return
}
defer store.Close()

if err := neomigrations.InitializeSchema(store.db, store.dbname); err != nil {
fmt.Println(err)
Expand All @@ -34,11 +34,6 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

func TestClose(t *testing.T) {
err := store.Close()
assert.NoError(t, err)
}

func TestGetDBType(t *testing.T) {
if db := store.GetDBType(); db != Neo4j {
t.Errorf("Failed to return the correct database type")
Expand Down
258 changes: 165 additions & 93 deletions repository/neo4j/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,65 +7,128 @@ package neo4j
import (
"context"
"errors"
"strconv"
"fmt"
"time"

"github.com/google/uuid"
neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j"
"github.com/owasp-amass/asset-db/types"
"gorm.io/gorm"
oam "github.com/owasp-amass/open-asset-model"
)

// CreateEntity creates a new entity in the database.
// It takes an Entity as input and persists it in the database.
// The asset is serialized to JSON and stored in the Content field of the Entity struct.
// Returns the created entity as a types.Entity or an error if the creation fails.
func (neo *neoRepository) CreateEntity(input *types.Entity) (*types.Entity, error) {
jsonContent, err := input.Asset.JSON()
if err != nil {
return nil, err
}
var entity *types.Entity

entity := Entity{
Type: string(input.Asset.AssetType()),
Content: jsonContent,
if input == nil {
return nil, errors.New("the input entity is nil")
}

// ensure that duplicate entities are not entered into the database
if entities, err := neo.FindEntitiesByContent(input.Asset, time.Time{}); err == nil && len(entities) == 1 {
e := entities[0]

if input.Asset.AssetType() == e.Asset.AssetType() {
if id, err := strconv.ParseUint(e.ID, 10, 64); err == nil {
entity.ID = id
entity.CreatedAt = e.CreatedAt
entity.UpdatedAt = time.Now().UTC()
}
if input.Asset.AssetType() != e.Asset.AssetType() {
return nil, errors.New("the asset type does not match the existing entity")
}

qnode, err := queryNodeByAssetKey("a", e.Asset)
if err != nil {
return nil, err
}

e.LastSeen = time.Now()
props, err := entityPropsMap(e)
if err != nil {
return nil, err
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

result, err := neo4jdb.ExecuteQuery(ctx, neo.db,
"MATCH "+qnode+" SET a = $props RETURN a",
map[string]interface{}{"props": props},
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
if err != nil {
return nil, err
}
if len(result.Records) == 0 {
return nil, errors.New("no records returned from the query")
}

node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "a")
if err != nil {
return nil, err
}
if isnil {
return nil, errors.New("the record value for the node is nil")
}

if e, err := nodeToEntity(node); err == nil && e != nil {
entity = e
}
} else {
if input.ID == "" {
input.ID = neo.uniqueEntityID()
}
if input.CreatedAt.IsZero() {
entity.CreatedAt = time.Now().UTC()
entity.CreatedAt = time.Now()
} else {
entity.CreatedAt = input.CreatedAt.UTC()
entity.CreatedAt = input.CreatedAt
}

if input.LastSeen.IsZero() {
entity.UpdatedAt = time.Now().UTC()
entity.LastSeen = time.Now()
} else {
entity.UpdatedAt = input.LastSeen.UTC()
entity.LastSeen = input.LastSeen
}
}

result := sql.db.Save(&entity)
if err := result.Error; err != nil {
return nil, err
props, err := entityPropsMap(input)
if err != nil {
return nil, err
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

result, err := neo4jdb.ExecuteQuery(ctx, neo.db,
"CREATE (a:$($labels) $props) RETURN a",
map[string]interface{}{
"labels": []string{"Entity", string(input.Asset.AssetType())},
"props": props,
},
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
if err != nil {
return nil, err
}
if len(result.Records) == 0 {
return nil, errors.New("no records returned from the query")
}

node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "a")
if err != nil {
return nil, err
}
if isnil {
return nil, errors.New("the record value for the node is nil")
}

if e, err := nodeToEntity(node); err == nil && e != nil {
entity = e
}
}

return &types.Entity{
ID: strconv.FormatUint(entity.ID, 10),
CreatedAt: entity.CreatedAt.In(time.UTC).Local(),
LastSeen: entity.UpdatedAt.In(time.UTC).Local(),
Asset: input.Asset,
}, nil
if entity == nil {
return nil, errors.New("failed to create the entity")
}
return entity, nil
}

// CreateAsset creates a new entity in the database.
Expand All @@ -76,41 +139,43 @@ func (neo *neoRepository) CreateAsset(asset oam.Asset) (*types.Entity, error) {
return neo.CreateEntity(&types.Entity{Asset: asset})
}

// UpdateEntityLastSeen performs an update on the entity.
func (neo *neoRepository) UpdateEntityLastSeen(id string) error {
result := sql.db.Exec("UPDATE entities SET updated_at = current_timestamp WHERE entity_id = ?", id)
if err := result.Error; err != nil {
return err
func (neo *neoRepository) uniqueEntityID() string {
for {
id := uuid.New().String()
if _, err := neo.FindEntityById(id); err != nil {
return id
}
}
return nil
}

// FindEntityById finds an entity in the database by the ID.
// It takes a string representing the entity ID and retrieves the corresponding entity from the database.
// Returns the found entity as a types.Entity or an error if the asset is not found.
func (neo *neoRepository) FindEntityById(id string) (*types.Entity, error) {
entityId, err := strconv.ParseUint(id, 10, 64)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

result, err := neo4jdb.ExecuteQuery(ctx, neo.db,
"MATCH (a:Entity {entity_id: $entity_id}) RETURN a",
map[string]interface{}{"entity_id": id},
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
if err != nil {
return nil, err
}

entity := Entity{ID: entityId}
result := sql.db.First(&entity)
if err := result.Error; err != nil {
return nil, err
if len(result.Records) == 0 {
return nil, fmt.Errorf("the entity with ID %s was not found", id)
}

assetData, err := entity.Parse()
node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "a")
if err != nil {
return nil, err
}

return &types.Entity{
ID: strconv.FormatUint(entity.ID, 10),
CreatedAt: entity.CreatedAt.In(time.UTC).Local(),
LastSeen: entity.UpdatedAt.In(time.UTC).Local(),
Asset: assetData,
}, nil
if isnil {
return nil, errors.New("the record value for the node is nil")
}
return nodeToEntity(node)
}

// FindEntitiesByContent finds entities in the database that match the provided asset data and last seen after
Expand All @@ -119,77 +184,84 @@ func (neo *neoRepository) FindEntityById(id string) (*types.Entity, error) {
// The asset data is serialized to JSON and compared against the Content field of the Entity struct.
// Returns a slice of matching entities as []*types.Entity or an error if the search fails.
func (neo *neoRepository) FindEntitiesByContent(assetData oam.Asset, since time.Time) ([]*types.Entity, error) {
jsonContent, err := assetData.JSON()
qnode, err := queryNodeByAssetKey("a", assetData)
if err != nil {
return nil, err
}

entity := Entity{
Type: string(assetData.AssetType()),
Content: jsonContent,
query := "MATCH " + qnode + " RETURN a"
if !since.IsZero() {
query = fmt.Sprintf("MATCH %s WHERE a.updated_at >= '%s' RETURN a", qnode, timeToNeo4jTime(since))
}

jsonQuery, err := entity.JSONQuery()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil,
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
if err != nil {
return nil, err
}

tx := sql.db.Where("etype = ?", entity.Type)
if !since.IsZero() {
tx = tx.Where("updated_at >= ?", since.UTC())
if len(result.Records) == 0 {
return nil, errors.New("no entities found")
}

var entities []Entity
tx = tx.Where(jsonQuery).Find(&entities)
if err := tx.Error; err != nil {
node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](result.Records[0], "a")
if err != nil {
return nil, err
}

var results []*types.Entity
for _, e := range entities {
if assetData, err := e.Parse(); err == nil {
results = append(results, &types.Entity{
ID: strconv.FormatUint(e.ID, 10),
CreatedAt: e.CreatedAt.In(time.UTC).Local(),
LastSeen: e.UpdatedAt.In(time.UTC).Local(),
Asset: assetData,
})
}
if isnil {
return nil, errors.New("the record value for the node is nil")
}

if len(results) == 0 {
return nil, errors.New("zero entities found")
e, err := nodeToEntity(node)
if err != nil {
return nil, err
}
return results, nil
return []*types.Entity{e}, nil
}

// FindEntitiesByType finds all entities in the database of the provided asset type and last seen after the since parameter.
// It takes an asset type and retrieves the corresponding entities from the database.
// If since.IsZero(), the parameter will be ignored.
// Returns a slice of matching entities as []*types.Entity or an error if the search fails.
func (neo *neoRepository) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error) {
var entities []Entity
var result *gorm.DB

if since.IsZero() {
result = sql.db.Where("etype = ?", atype).Find(&entities)
} else {
result = sql.db.Where("etype = ? AND updated_at >= ?", atype, since.UTC()).Find(&entities)
query := fmt.Sprintf("MATCH (a:%s) RETURN a", string(atype))
if !since.IsZero() {
query = fmt.Sprintf("MATCH (a:%s) WHERE a.updated_at >= '%s' RETURN a", string(atype), timeToNeo4jTime(since))
}
if err := result.Error; err != nil {

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil,
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
if err != nil {
return nil, err
}
if len(result.Records) == 0 {
return nil, errors.New("no entities of the specified type")
}

var results []*types.Entity
for _, e := range entities {
if f, err := e.Parse(); err == nil {
results = append(results, &types.Entity{
ID: strconv.FormatUint(e.ID, 10),
CreatedAt: e.CreatedAt.In(time.UTC).Local(),
LastSeen: e.UpdatedAt.In(time.UTC).Local(),
Asset: f,
})
for _, record := range result.Records {
node, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Node](record, "a")
if err != nil {
return nil, err
}
if isnil {
return nil, errors.New("the record value for the node is nil")
}

e, err := nodeToEntity(node)
if err != nil {
return nil, err
}
results = append(results, e)
}

if len(results) == 0 {
Expand All @@ -202,7 +274,7 @@ func (neo *neoRepository) FindEntitiesByType(atype oam.AssetType, since time.Tim
// It takes a string representing the entity ID and removes the corresponding entity from the database.
// Returns an error if the entity is not found.
func (neo *neoRepository) DeleteEntity(id string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

_, err := neo4jdb.ExecuteQuery(ctx, neo.db,
Expand Down
Loading

0 comments on commit e3058c1

Please sign in to comment.