Skip to content

Commit

Permalink
initial implementation of tag search by content methods
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Dec 3, 2024
1 parent a937f1f commit 8cdb824
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 44 deletions.
10 changes: 10 additions & 0 deletions assetdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,11 @@ func (m *mockAssetDB) FindEntityTagById(id string) (*types.EntityTag, error) {
return args.Get(0).(*types.EntityTag), args.Error(1)
}

func (m *mockAssetDB) FindEntityTagsByContent(prop oam.Property, since time.Time) ([]*types.EntityTag, error) {
args := m.Called(prop, since)
return args.Get(0).([]*types.EntityTag), args.Error(1)
}

func (m *mockAssetDB) GetEntityTags(entity *types.Entity, since time.Time, names ...string) ([]*types.EntityTag, error) {
args := m.Called(entity, since, names)
return args.Get(0).([]*types.EntityTag), args.Error(1)
Expand All @@ -479,6 +484,11 @@ func (m *mockAssetDB) FindEdgeTagById(id string) (*types.EdgeTag, error) {
return args.Get(0).(*types.EdgeTag), args.Error(1)
}

func (m *mockAssetDB) FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*types.EdgeTag, error) {
args := m.Called(prop, since)
return args.Get(0).([]*types.EdgeTag), args.Error(1)
}

func (m *mockAssetDB) GetEdgeTags(edge *types.Edge, since time.Time, names ...string) ([]*types.EdgeTag, error) {
args := m.Called(edge, since, names)
return args.Get(0).([]*types.EdgeTag), args.Error(1)
Expand Down
4 changes: 3 additions & 1 deletion repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ type Repository interface {
GetDBType() string
CreateEntity(entity *types.Entity) (*types.Entity, error)
CreateAsset(asset oam.Asset) (*types.Entity, error)
DeleteEntity(id string) error
FindEntityById(id string) (*types.Entity, error)
FindEntityByContent(asset oam.Asset, since time.Time) ([]*types.Entity, error)
FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error)
DeleteEntity(id string) error
CreateEdge(edge *types.Edge) (*types.Edge, error)
FindEdgeById(id string) (*types.Edge, error)
IncomingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error)
Expand All @@ -32,11 +32,13 @@ type Repository interface {
CreateEntityTag(entity *types.Entity, tag *types.EntityTag) (*types.EntityTag, error)
CreateEntityProperty(entity *types.Entity, property oam.Property) (*types.EntityTag, error)
FindEntityTagById(id string) (*types.EntityTag, error)
FindEntityTagsByContent(prop oam.Property, since time.Time) ([]*types.EntityTag, error)
GetEntityTags(entity *types.Entity, since time.Time, names ...string) ([]*types.EntityTag, error)
DeleteEntityTag(id string) error
CreateEdgeTag(edge *types.Edge, tag *types.EdgeTag) (*types.EdgeTag, error)
CreateEdgeProperty(edge *types.Edge, property oam.Property) (*types.EdgeTag, error)
FindEdgeTagById(id string) (*types.EdgeTag, error)
FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*types.EdgeTag, error)
GetEdgeTags(edge *types.Edge, since time.Time, names ...string) ([]*types.EdgeTag, error)
DeleteEdgeTag(id string) error
Close() error
Expand Down
86 changes: 43 additions & 43 deletions repository/sqlrepo/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,35 @@ func (sql *sqlRepository) UpdateEntityLastSeen(id string) error {
return nil
}

// DeleteEntity removes an entity in the database by its ID.
// It takes a string representing the entity ID and removes the corresponding entity from the database.
// Returns an error if the entity is not found.
func (sql *sqlRepository) DeleteEntity(id string) error {
// FindEntityById finds an entity in the database by the ID.
// It takes a string representing the entity ID and retrieves the corresponding entity from the database.
// Returns the found entity as a types.Entity or an error if the asset is not found.
func (sql *sqlRepository) FindEntityById(id string) (*types.Entity, error) {
entityId, err := strconv.ParseUint(id, 10, 64)
if err != nil {
return err
return nil, err
}

entity := Entity{ID: entityId}
result := sql.db.Delete(&entity)
return result.Error
result := sql.db.First(&entity)
if err := result.Error; err != nil {
return nil, err
}

assetData, err := entity.Parse()
if err != nil {
return nil, err
}

return &types.Entity{
ID: strconv.FormatUint(entity.ID, 10),
CreatedAt: entity.CreatedAt.In(time.UTC).Local(),
LastSeen: entity.UpdatedAt.In(time.UTC).Local(),
Asset: assetData,
}, nil
}

// FindEntityByContent finds entity in the database that match the provided asset data and last seen after the since parameter.
// FindEntityByContent finds entities in the database that match the provided asset data and last seen after the since parameter.
// It takes an oam.Asset as input and searches for entities with matching content in the database.
// If since.IsZero(), the parameter will be ignored.
// The asset data is serialized to JSON and compared against the Content field of the Entity struct.
Expand All @@ -119,14 +133,14 @@ func (sql *sqlRepository) FindEntityByContent(assetData oam.Asset, since time.Ti
return nil, err
}

var entities []Entity
var result *gorm.DB
if since.IsZero() {
result = sql.db.Where("etype = ?", entity.Type).Find(&entities, jsonQuery)
} else {
result = sql.db.Where("etype = ? AND updated_at >= ?", entity.Type, since.UTC()).Find(&entities, jsonQuery)
tx := sql.db.Where("etype = ?", entity.Type)
if !since.IsZero() {
tx = tx.Where("updated_at >= ?", since.UTC())
}
if err := result.Error; err != nil {

var entities []Entity
tx = tx.Where(jsonQuery).Find(&entities)
if err := tx.Error; err != nil {
return nil, err
}

Expand All @@ -148,34 +162,6 @@ func (sql *sqlRepository) FindEntityByContent(assetData oam.Asset, since time.Ti
return results, nil
}

// FindEntityById finds an entity in the database by the ID.
// It takes a string representing the entity ID and retrieves the corresponding entity from the database.
// Returns the found entity as a types.Entity or an error if the asset is not found.
func (sql *sqlRepository) FindEntityById(id string) (*types.Entity, error) {
entityId, err := strconv.ParseUint(id, 10, 64)
if err != nil {
return nil, err
}

entity := Entity{ID: entityId}
result := sql.db.First(&entity)
if err := result.Error; err != nil {
return nil, err
}

assetData, err := entity.Parse()
if err != nil {
return nil, err
}

return &types.Entity{
ID: strconv.FormatUint(entity.ID, 10),
CreatedAt: entity.CreatedAt.In(time.UTC).Local(),
LastSeen: entity.UpdatedAt.In(time.UTC).Local(),
Asset: assetData,
}, nil
}

// FindEntitiesByType finds all entities in the database of the provided asset type and last seen after the since parameter.
// It takes an asset type and retrieves the corresponding entities from the database.
// If since.IsZero(), the parameter will be ignored.
Expand Down Expand Up @@ -210,3 +196,17 @@ func (sql *sqlRepository) FindEntitiesByType(atype oam.AssetType, since time.Tim
}
return results, nil
}

// DeleteEntity removes an entity in the database by its ID.
// It takes a string representing the entity ID and removes the corresponding entity from the database.
// Returns an error if the entity is not found.
func (sql *sqlRepository) DeleteEntity(id string) error {
entityId, err := strconv.ParseUint(id, 10, 64)
if err != nil {
return err
}

entity := Entity{ID: entityId}
result := sql.db.Delete(&entity)
return result.Error
}
74 changes: 74 additions & 0 deletions repository/sqlrepo/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,77 @@ func parseProperty(ptype string, content datatypes.JSON) (oam.Property, error) {

return prop, err
}

// NameJSONQuery generates the JSON query for the field returned by the Property Name method.
// It returns the parsed property and an error, if any.
func (e *EntityTag) NameJSONQuery() (*datatypes.JSONQueryExpression, error) {
prop, err := e.Parse()
if err != nil {
return nil, err
}

return propertyNameJSONQuery(prop)
}

// NameJSONQuery generates the JSON query for the field returned by the Property Name method.
// It returns the parsed property and an error, if any.
func (e *EdgeTag) NameJSONQuery() (*datatypes.JSONQueryExpression, error) {
prop, err := e.Parse()
if err != nil {
return nil, err
}

return propertyNameJSONQuery(prop)
}

func propertyNameJSONQuery(prop oam.Property) (*datatypes.JSONQueryExpression, error) {
jsonQuery := datatypes.JSONQuery("content")

switch v := prop.(type) {
case *property.SimpleProperty:
return jsonQuery.Equals(v.PropertyName, "property_name"), nil
case *property.SourceProperty:
return jsonQuery.Equals(v.Source, "name"), nil
case *property.VulnProperty:
return jsonQuery.Equals(v.ID, "id"), nil
}

return nil, fmt.Errorf("unknown property type: %s", prop.PropertyType())
}

// ValueJSONQuery generates the JSON query for the field returned by the Property Value method.
// It returns the parsed property and an error, if any.
func (e *EntityTag) ValueJSONQuery() (*datatypes.JSONQueryExpression, error) {
prop, err := e.Parse()
if err != nil {
return nil, err
}

return propertyValueJSONQuery(prop)
}

// ValueJSONQuery generates the JSON query for the field returned by the Property Value method.
// It returns the parsed property and an error, if any.
func (e *EdgeTag) ValueJSONQuery() (*datatypes.JSONQueryExpression, error) {
prop, err := e.Parse()
if err != nil {
return nil, err
}

return propertyValueJSONQuery(prop)
}

func propertyValueJSONQuery(prop oam.Property) (*datatypes.JSONQueryExpression, error) {
jsonQuery := datatypes.JSONQuery("content")

switch v := prop.(type) {
case *property.SimpleProperty:
return jsonQuery.Equals(v.PropertyValue, "property_value"), nil
case *property.SourceProperty:
return jsonQuery.Equals(v.Confidence, "confidence"), nil
case *property.VulnProperty:
return jsonQuery.Equals(v.Description, "desc"), nil
}

return nil, fmt.Errorf("unknown property type: %s", prop.PropertyType())
}
110 changes: 110 additions & 0 deletions repository/sqlrepo/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,61 @@ func (sql *sqlRepository) FindEntityTagById(id string) (*types.EntityTag, error)
}, nil
}

// FindEntityTagsByContent finds entity tags in the database that match the provided property data and updated_at after the since parameter.
// It takes an oam.Property as input and searches for entity tags with matching content in the database.
// If since.IsZero(), the parameter will be ignored.
// The property data is serialized to JSON and compared against the Content field of the EntityTag struct.
// Returns a slice of matching entity tags as []*types.EntityTag or an error if the search fails.
func (sql *sqlRepository) FindEntityTagsByContent(prop oam.Property, since time.Time) ([]*types.EntityTag, error) {
jsonContent, err := prop.JSON()
if err != nil {
return nil, err
}

tag := EntityTag{
Type: string(prop.PropertyType()),
Content: jsonContent,
}

nameQuery, err := tag.NameJSONQuery()
if err != nil {
return nil, err
}

valueQuery, err := tag.ValueJSONQuery()
if err != nil {
return nil, err
}

tx := sql.db.Where("ttype = ?", tag.Type)
if !since.IsZero() {
tx = tx.Where("updated_at >= ?", since.UTC())
}

var tags []EntityTag
tx = tx.Where(nameQuery).Where(valueQuery).Find(&tags)
if err := tx.Error; err != nil {
return nil, err
}

var results []*types.EntityTag
for _, t := range tags {
if propData, err := t.Parse(); err == nil {
results = append(results, &types.EntityTag{
ID: strconv.FormatUint(t.ID, 10),
CreatedAt: t.CreatedAt.In(time.UTC).Local(),
LastSeen: t.UpdatedAt.In(time.UTC).Local(),
Property: propData,
})
}
}

if len(results) == 0 {
return nil, errors.New("zero entity tags found")
}
return results, nil
}

// GetEntityTags finds all tags for the entity with the specified names and last seen after the since parameter.
// If since.IsZero(), the parameter will be ignored.
// If no names are specified, all tags for the specified entity are returned.
Expand Down Expand Up @@ -307,6 +362,61 @@ func (sql *sqlRepository) FindEdgeTagById(id string) (*types.EdgeTag, error) {
}, nil
}

// FindEdgeTagsByContent finds edge tags in the database that match the provided property data and updated_at after the since parameter.
// It takes an oam.Property as input and searches for edge tags with matching content in the database.
// If since.IsZero(), the parameter will be ignored.
// The property data is serialized to JSON and compared against the Content field of the EdgeTag struct.
// Returns a slice of matching edge tags as []*types.EdgeTag or an error if the search fails.
func (sql *sqlRepository) FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*types.EdgeTag, error) {
jsonContent, err := prop.JSON()
if err != nil {
return nil, err
}

tag := EdgeTag{
Type: string(prop.PropertyType()),
Content: jsonContent,
}

nameQuery, err := tag.NameJSONQuery()
if err != nil {
return nil, err
}

valueQuery, err := tag.ValueJSONQuery()
if err != nil {
return nil, err
}

tx := sql.db.Where("ttype = ?", tag.Type)
if !since.IsZero() {
tx = tx.Where("updated_at >= ?", since.UTC())
}

var tags []EdgeTag
tx = tx.Where(nameQuery).Where(valueQuery).Find(&tags)
if err := tx.Error; err != nil {
return nil, err
}

var results []*types.EdgeTag
for _, t := range tags {
if propData, err := t.Parse(); err == nil {
results = append(results, &types.EdgeTag{
ID: strconv.FormatUint(t.ID, 10),
CreatedAt: t.CreatedAt.In(time.UTC).Local(),
LastSeen: t.UpdatedAt.In(time.UTC).Local(),
Property: propData,
})
}
}

if len(results) == 0 {
return nil, errors.New("zero edge tags found")
}
return results, nil
}

// GetEdgeTags finds all tags for the edge with the specified names and last seen after the since parameter.
// If since.IsZero(), the parameter will be ignored.
// If no names are specified, all tags for the specified edge are returned.
Expand Down

0 comments on commit 8cdb824

Please sign in to comment.