Skip to content

Commit

Permalink
neo4j entity method unit tests complete
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Dec 18, 2024
1 parent e3058c1 commit e1d9eaf
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 73 deletions.
7 changes: 4 additions & 3 deletions repository/neo4j/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@ import (
var store *neoRepository

func TestMain(m *testing.M) {
var err error
dsn := "bolt://neo4j:hackme4fun@localhost:7687/amass"

store, err := New("neo4j", dsn)
store, err = New("neo4j", dsn)
if err != nil {
fmt.Println(err)
return
os.Exit(1)
}
defer store.Close()

if err := neomigrations.InitializeSchema(store.db, store.dbname); err != nil {
fmt.Println(err)
return
os.Exit(1)
}

os.Exit(m.Run())
Expand Down
26 changes: 9 additions & 17 deletions repository/neo4j/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,10 @@ func (neo *neoRepository) CreateEntity(input *types.Entity) (*types.Entity, erro
input.ID = neo.uniqueEntityID()
}
if input.CreatedAt.IsZero() {
entity.CreatedAt = time.Now()
} else {
entity.CreatedAt = input.CreatedAt
input.CreatedAt = time.Now()
}

if input.LastSeen.IsZero() {
entity.LastSeen = time.Now()
} else {
entity.LastSeen = input.LastSeen
input.LastSeen = time.Now()
}

props, err := entityPropsMap(input)
Expand All @@ -96,12 +91,9 @@ func (neo *neoRepository) CreateEntity(input *types.Entity) (*types.Entity, erro
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

result, err := neo4jdb.ExecuteQuery(ctx, neo.db,
"CREATE (a:$($labels) $props) RETURN a",
map[string]interface{}{
"labels": []string{"Entity", string(input.Asset.AssetType())},
"props": props,
},
query := fmt.Sprintf("CREATE (a:Entity:%s $props) RETURN a", input.Asset.AssetType())
result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query,
map[string]interface{}{"props": props},
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
Expand Down Expand Up @@ -156,8 +148,8 @@ func (neo *neoRepository) FindEntityById(id string) (*types.Entity, error) {
defer cancel()

result, err := neo4jdb.ExecuteQuery(ctx, neo.db,
"MATCH (a:Entity {entity_id: $entity_id}) RETURN a",
map[string]interface{}{"entity_id": id},
"MATCH (a:Entity {entity_id: $eid}) RETURN a",
map[string]interface{}{"eid": id},
neo4jdb.EagerResultTransformer,
neo4jdb.ExecuteQueryWithDatabase(neo.dbname),
)
Expand Down Expand Up @@ -191,7 +183,7 @@ func (neo *neoRepository) FindEntitiesByContent(assetData oam.Asset, since time.

query := "MATCH " + qnode + " RETURN a"
if !since.IsZero() {
query = fmt.Sprintf("MATCH %s WHERE a.updated_at >= '%s' RETURN a", qnode, timeToNeo4jTime(since))
query = fmt.Sprintf("MATCH %s WHERE a.updated_at >= localDateTime('%s') RETURN a", qnode, timeToNeo4jTime(since))
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
Expand Down Expand Up @@ -230,7 +222,7 @@ func (neo *neoRepository) FindEntitiesByContent(assetData oam.Asset, since time.
func (neo *neoRepository) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error) {
query := fmt.Sprintf("MATCH (a:%s) RETURN a", string(atype))
if !since.IsZero() {
query = fmt.Sprintf("MATCH (a:%s) WHERE a.updated_at >= '%s' RETURN a", string(atype), timeToNeo4jTime(since))
query = fmt.Sprintf("MATCH (a:%s) WHERE a.updated_at >= localDateTime('%s') RETURN a", string(atype), timeToNeo4jTime(since))
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
Expand Down
122 changes: 113 additions & 9 deletions repository/neo4j/entity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,35 @@
package neo4j

import (
"fmt"
"net/netip"
"testing"
"time"

"github.com/owasp-amass/asset-db/types"
oam "github.com/owasp-amass/open-asset-model"
"github.com/owasp-amass/open-asset-model/domain"
oamnet "github.com/owasp-amass/open-asset-model/network"
"github.com/owasp-amass/open-asset-model/org"
"github.com/stretchr/testify/assert"
)

func TestCreateEntity(t *testing.T) {
entity, err := store.CreateEntity(&types.Entity{
Asset: &domain.FQDN{
Name: "test.com",
Name: "create1.entity",
},
})
if err != nil {
t.Errorf("Failed to create the entity: %v", err)
}
assert.NoError(t, err)

time.Sleep(500 * time.Millisecond)
time.Sleep(250 * time.Millisecond)
newer, err := store.CreateEntity(&types.Entity{
Asset: &domain.FQDN{
Name: "test.com",
Name: "create1.entity",
},
})
if err != nil {
t.Errorf("Failed to create the newer entity: %v", err)
}
assert.NoError(t, err)

if entity.ID != newer.ID || entity.Asset.AssetType() != newer.Asset.AssetType() {
t.Errorf("Failed to prevent duplicate entities from being created")
}
Expand All @@ -49,4 +52,105 @@ func TestCreateEntity(t *testing.T) {
if !entity.LastSeen.Before(newer.LastSeen) {
t.Errorf("Failed to update the LastSeen timestamp")
}

time.Sleep(250 * time.Millisecond)
second, err := store.CreateEntity(&types.Entity{
Asset: &domain.FQDN{
Name: "create2.entity",
},
})
assert.NoError(t, err)

if newer.ID == second.ID {
t.Errorf("Failed to create an unique entity_id for the second entity")
}
if !second.CreatedAt.After(newer.LastSeen) {
t.Errorf("Failed to assign the second entity an accurate creation time")
}
}

func TestFindEntityById(t *testing.T) {
entity, err := store.CreateEntity(&types.Entity{
Asset: &domain.FQDN{
Name: "find1.entity",
},
})
assert.NoError(t, err)

same, err := store.FindEntityById(entity.ID)
if err != nil {
t.Errorf("Failed to find the entity: %v", err)
}
if entity.ID != same.ID {
t.Errorf("Failed to return an entity with the correct ID")
}
if fqdn1, ok := entity.Asset.(*domain.FQDN); !ok {
t.Errorf("Failed to type assert the first asset")
} else if fqdn2, ok := same.Asset.(*domain.FQDN); !ok {
t.Errorf("Failed to type assert the second asset")
} else if fqdn1.Name != fqdn2.Name {
t.Errorf("Failed to return an entity with the correct name")
}
}

func TestFindEntitiesByContent(t *testing.T) {
fqdn := &domain.FQDN{Name: "findcontent.entity"}

_, err := store.FindEntitiesByContent(fqdn, time.Time{})
assert.Error(t, err)

entity, err := store.CreateAsset(fqdn)
assert.NoError(t, err)

e, err := store.FindEntitiesByContent(fqdn, entity.CreatedAt.Add(-1*time.Second))
assert.NoError(t, err)
same := e[0]

if entity.ID != same.ID {
t.Errorf("Failed to return an entity with the correct ID")
}
if fqdn1, ok := entity.Asset.(*domain.FQDN); !ok {
t.Errorf("Failed to type assert the first asset")
} else if fqdn2, ok := same.Asset.(*domain.FQDN); !ok {
t.Errorf("Failed to type assert the second asset")
} else if fqdn1.Name != fqdn2.Name {
t.Errorf("Failed to return an entity with the correct name")
}

_, err = store.FindEntitiesByContent(fqdn, entity.CreatedAt.Add(250*time.Millisecond))
assert.Error(t, err)
}

func TestFindEntitiesByType(t *testing.T) {
now := time.Now()

for i := 1; i <= 10; i++ {
addr, err := netip.ParseAddr(fmt.Sprintf("192.168.1.%d", i))
assert.NoError(t, err)

_, err = store.CreateAsset(&oamnet.IPAddress{
Address: addr,
Type: "IPv4",
})
assert.NoError(t, err)
}

entities, err := store.FindEntitiesByType(oam.IPAddress, time.Time{})
assert.NoError(t, err)

if len(entities) < 10 {
t.Errorf("Failed to return the correct number of entities")
}

for i := 1; i <= 10; i++ {
_, err := store.CreateAsset(&org.Organization{Name: fmt.Sprintf("findtype%d.entity", i)})
assert.NoError(t, err)
}

entities, err = store.FindEntitiesByType(oam.Organization, now)
assert.NoError(t, err)

if len(entities) < 10 {
t.Errorf("Failed to return the correct number of entities")
}
}
5 changes: 2 additions & 3 deletions repository/neo4j/extract_entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package neo4j
import (
"errors"
"net/netip"
"time"

neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j"
"github.com/owasp-amass/asset-db/types"
Expand All @@ -34,13 +33,13 @@ func nodeToEntity(node neo4jdb.Node) (*types.Entity, error) {
if err != nil {
return nil, err
}
created := t.Time().In(time.UTC).Local()
created := neo4jTimeToTime(t)

t, err = neo4jdb.GetProperty[neo4jdb.LocalDateTime](node, "updated_at")
if err != nil {
return nil, err
}
updated := t.Time().In(time.UTC).Local()
updated := neo4jTimeToTime(t)

etype, err := neo4jdb.GetProperty[string](node, "etype")
if err != nil {
Expand Down
23 changes: 23 additions & 0 deletions repository/neo4j/neo4j_time.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright © by Jeff Foley 2017-2024. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// SPDX-License-Identifier: Apache-2.0

package neo4j

import (
"time"

"github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype"
)

func neo4jTimeToTime(neot dbtype.LocalDateTime) time.Time {
t := neot.Time()
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC)
return t.Local()
}

func timeToNeo4jTime(t time.Time) dbtype.LocalDateTime {
t = t.UTC()
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC)
return dbtype.LocalDateTime(t)
}
Loading

0 comments on commit e1d9eaf

Please sign in to comment.