From 8cdb8247a349775185acc32ed3eaf022e70c735c Mon Sep 17 00:00:00 2001 From: caffix Date: Mon, 2 Dec 2024 20:30:07 -0500 Subject: [PATCH] initial implementation of tag search by content methods --- assetdb_test.go | 10 ++++ repository/repository.go | 4 +- repository/sqlrepo/entity.go | 86 +++++++++++++-------------- repository/sqlrepo/models.go | 74 +++++++++++++++++++++++ repository/sqlrepo/tag.go | 110 +++++++++++++++++++++++++++++++++++ 5 files changed, 240 insertions(+), 44 deletions(-) diff --git a/assetdb_test.go b/assetdb_test.go index 9b7cb0c..07859db 100644 --- a/assetdb_test.go +++ b/assetdb_test.go @@ -454,6 +454,11 @@ func (m *mockAssetDB) FindEntityTagById(id string) (*types.EntityTag, error) { return args.Get(0).(*types.EntityTag), args.Error(1) } +func (m *mockAssetDB) FindEntityTagsByContent(prop oam.Property, since time.Time) ([]*types.EntityTag, error) { + args := m.Called(prop, since) + return args.Get(0).([]*types.EntityTag), args.Error(1) +} + func (m *mockAssetDB) GetEntityTags(entity *types.Entity, since time.Time, names ...string) ([]*types.EntityTag, error) { args := m.Called(entity, since, names) return args.Get(0).([]*types.EntityTag), args.Error(1) @@ -479,6 +484,11 @@ func (m *mockAssetDB) FindEdgeTagById(id string) (*types.EdgeTag, error) { return args.Get(0).(*types.EdgeTag), args.Error(1) } +func (m *mockAssetDB) FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*types.EdgeTag, error) { + args := m.Called(prop, since) + return args.Get(0).([]*types.EdgeTag), args.Error(1) +} + func (m *mockAssetDB) GetEdgeTags(edge *types.Edge, since time.Time, names ...string) ([]*types.EdgeTag, error) { args := m.Called(edge, since, names) return args.Get(0).([]*types.EdgeTag), args.Error(1) diff --git a/repository/repository.go b/repository/repository.go index 953a11e..710dafb 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -20,10 +20,10 @@ type Repository interface { GetDBType() string CreateEntity(entity *types.Entity) (*types.Entity, error) CreateAsset(asset oam.Asset) (*types.Entity, error) - DeleteEntity(id string) error FindEntityById(id string) (*types.Entity, error) FindEntityByContent(asset oam.Asset, since time.Time) ([]*types.Entity, error) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error) + DeleteEntity(id string) error CreateEdge(edge *types.Edge) (*types.Edge, error) FindEdgeById(id string) (*types.Edge, error) IncomingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) @@ -32,11 +32,13 @@ type Repository interface { CreateEntityTag(entity *types.Entity, tag *types.EntityTag) (*types.EntityTag, error) CreateEntityProperty(entity *types.Entity, property oam.Property) (*types.EntityTag, error) FindEntityTagById(id string) (*types.EntityTag, error) + FindEntityTagsByContent(prop oam.Property, since time.Time) ([]*types.EntityTag, error) GetEntityTags(entity *types.Entity, since time.Time, names ...string) ([]*types.EntityTag, error) DeleteEntityTag(id string) error CreateEdgeTag(edge *types.Edge, tag *types.EdgeTag) (*types.EdgeTag, error) CreateEdgeProperty(edge *types.Edge, property oam.Property) (*types.EdgeTag, error) FindEdgeTagById(id string) (*types.EdgeTag, error) + FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*types.EdgeTag, error) GetEdgeTags(edge *types.Edge, since time.Time, names ...string) ([]*types.EdgeTag, error) DeleteEdgeTag(id string) error Close() error diff --git a/repository/sqlrepo/entity.go b/repository/sqlrepo/entity.go index f2524ac..eefa4ab 100644 --- a/repository/sqlrepo/entity.go +++ b/repository/sqlrepo/entity.go @@ -84,21 +84,35 @@ func (sql *sqlRepository) UpdateEntityLastSeen(id string) error { return nil } -// DeleteEntity removes an entity in the database by its ID. -// It takes a string representing the entity ID and removes the corresponding entity from the database. -// Returns an error if the entity is not found. -func (sql *sqlRepository) DeleteEntity(id string) error { +// FindEntityById finds an entity in the database by the ID. +// It takes a string representing the entity ID and retrieves the corresponding entity from the database. +// Returns the found entity as a types.Entity or an error if the asset is not found. +func (sql *sqlRepository) FindEntityById(id string) (*types.Entity, error) { entityId, err := strconv.ParseUint(id, 10, 64) if err != nil { - return err + return nil, err } entity := Entity{ID: entityId} - result := sql.db.Delete(&entity) - return result.Error + result := sql.db.First(&entity) + if err := result.Error; err != nil { + return nil, err + } + + assetData, err := entity.Parse() + if err != nil { + return nil, err + } + + return &types.Entity{ + ID: strconv.FormatUint(entity.ID, 10), + CreatedAt: entity.CreatedAt.In(time.UTC).Local(), + LastSeen: entity.UpdatedAt.In(time.UTC).Local(), + Asset: assetData, + }, nil } -// FindEntityByContent finds entity in the database that match the provided asset data and last seen after the since parameter. +// FindEntityByContent finds entities in the database that match the provided asset data and last seen after the since parameter. // It takes an oam.Asset as input and searches for entities with matching content in the database. // If since.IsZero(), the parameter will be ignored. // The asset data is serialized to JSON and compared against the Content field of the Entity struct. @@ -119,14 +133,14 @@ func (sql *sqlRepository) FindEntityByContent(assetData oam.Asset, since time.Ti return nil, err } - var entities []Entity - var result *gorm.DB - if since.IsZero() { - result = sql.db.Where("etype = ?", entity.Type).Find(&entities, jsonQuery) - } else { - result = sql.db.Where("etype = ? AND updated_at >= ?", entity.Type, since.UTC()).Find(&entities, jsonQuery) + tx := sql.db.Where("etype = ?", entity.Type) + if !since.IsZero() { + tx = tx.Where("updated_at >= ?", since.UTC()) } - if err := result.Error; err != nil { + + var entities []Entity + tx = tx.Where(jsonQuery).Find(&entities) + if err := tx.Error; err != nil { return nil, err } @@ -148,34 +162,6 @@ func (sql *sqlRepository) FindEntityByContent(assetData oam.Asset, since time.Ti return results, nil } -// FindEntityById finds an entity in the database by the ID. -// It takes a string representing the entity ID and retrieves the corresponding entity from the database. -// Returns the found entity as a types.Entity or an error if the asset is not found. -func (sql *sqlRepository) FindEntityById(id string) (*types.Entity, error) { - entityId, err := strconv.ParseUint(id, 10, 64) - if err != nil { - return nil, err - } - - entity := Entity{ID: entityId} - result := sql.db.First(&entity) - if err := result.Error; err != nil { - return nil, err - } - - assetData, err := entity.Parse() - if err != nil { - return nil, err - } - - return &types.Entity{ - ID: strconv.FormatUint(entity.ID, 10), - CreatedAt: entity.CreatedAt.In(time.UTC).Local(), - LastSeen: entity.UpdatedAt.In(time.UTC).Local(), - Asset: assetData, - }, nil -} - // FindEntitiesByType finds all entities in the database of the provided asset type and last seen after the since parameter. // It takes an asset type and retrieves the corresponding entities from the database. // If since.IsZero(), the parameter will be ignored. @@ -210,3 +196,17 @@ func (sql *sqlRepository) FindEntitiesByType(atype oam.AssetType, since time.Tim } return results, nil } + +// DeleteEntity removes an entity in the database by its ID. +// It takes a string representing the entity ID and removes the corresponding entity from the database. +// Returns an error if the entity is not found. +func (sql *sqlRepository) DeleteEntity(id string) error { + entityId, err := strconv.ParseUint(id, 10, 64) + if err != nil { + return err + } + + entity := Entity{ID: entityId} + result := sql.db.Delete(&entity) + return result.Error +} diff --git a/repository/sqlrepo/models.go b/repository/sqlrepo/models.go index 422738b..f202355 100644 --- a/repository/sqlrepo/models.go +++ b/repository/sqlrepo/models.go @@ -292,3 +292,77 @@ func parseProperty(ptype string, content datatypes.JSON) (oam.Property, error) { return prop, err } + +// NameJSONQuery generates the JSON query for the field returned by the Property Name method. +// It returns the parsed property and an error, if any. +func (e *EntityTag) NameJSONQuery() (*datatypes.JSONQueryExpression, error) { + prop, err := e.Parse() + if err != nil { + return nil, err + } + + return propertyNameJSONQuery(prop) +} + +// NameJSONQuery generates the JSON query for the field returned by the Property Name method. +// It returns the parsed property and an error, if any. +func (e *EdgeTag) NameJSONQuery() (*datatypes.JSONQueryExpression, error) { + prop, err := e.Parse() + if err != nil { + return nil, err + } + + return propertyNameJSONQuery(prop) +} + +func propertyNameJSONQuery(prop oam.Property) (*datatypes.JSONQueryExpression, error) { + jsonQuery := datatypes.JSONQuery("content") + + switch v := prop.(type) { + case *property.SimpleProperty: + return jsonQuery.Equals(v.PropertyName, "property_name"), nil + case *property.SourceProperty: + return jsonQuery.Equals(v.Source, "name"), nil + case *property.VulnProperty: + return jsonQuery.Equals(v.ID, "id"), nil + } + + return nil, fmt.Errorf("unknown property type: %s", prop.PropertyType()) +} + +// ValueJSONQuery generates the JSON query for the field returned by the Property Value method. +// It returns the parsed property and an error, if any. +func (e *EntityTag) ValueJSONQuery() (*datatypes.JSONQueryExpression, error) { + prop, err := e.Parse() + if err != nil { + return nil, err + } + + return propertyValueJSONQuery(prop) +} + +// ValueJSONQuery generates the JSON query for the field returned by the Property Value method. +// It returns the parsed property and an error, if any. +func (e *EdgeTag) ValueJSONQuery() (*datatypes.JSONQueryExpression, error) { + prop, err := e.Parse() + if err != nil { + return nil, err + } + + return propertyValueJSONQuery(prop) +} + +func propertyValueJSONQuery(prop oam.Property) (*datatypes.JSONQueryExpression, error) { + jsonQuery := datatypes.JSONQuery("content") + + switch v := prop.(type) { + case *property.SimpleProperty: + return jsonQuery.Equals(v.PropertyValue, "property_value"), nil + case *property.SourceProperty: + return jsonQuery.Equals(v.Confidence, "confidence"), nil + case *property.VulnProperty: + return jsonQuery.Equals(v.Description, "desc"), nil + } + + return nil, fmt.Errorf("unknown property type: %s", prop.PropertyType()) +} diff --git a/repository/sqlrepo/tag.go b/repository/sqlrepo/tag.go index de086e0..6616958 100644 --- a/repository/sqlrepo/tag.go +++ b/repository/sqlrepo/tag.go @@ -121,6 +121,61 @@ func (sql *sqlRepository) FindEntityTagById(id string) (*types.EntityTag, error) }, nil } +// FindEntityTagsByContent finds entity tags in the database that match the provided property data and updated_at after the since parameter. +// It takes an oam.Property as input and searches for entity tags with matching content in the database. +// If since.IsZero(), the parameter will be ignored. +// The property data is serialized to JSON and compared against the Content field of the EntityTag struct. +// Returns a slice of matching entity tags as []*types.EntityTag or an error if the search fails. +func (sql *sqlRepository) FindEntityTagsByContent(prop oam.Property, since time.Time) ([]*types.EntityTag, error) { + jsonContent, err := prop.JSON() + if err != nil { + return nil, err + } + + tag := EntityTag{ + Type: string(prop.PropertyType()), + Content: jsonContent, + } + + nameQuery, err := tag.NameJSONQuery() + if err != nil { + return nil, err + } + + valueQuery, err := tag.ValueJSONQuery() + if err != nil { + return nil, err + } + + tx := sql.db.Where("ttype = ?", tag.Type) + if !since.IsZero() { + tx = tx.Where("updated_at >= ?", since.UTC()) + } + + var tags []EntityTag + tx = tx.Where(nameQuery).Where(valueQuery).Find(&tags) + if err := tx.Error; err != nil { + return nil, err + } + + var results []*types.EntityTag + for _, t := range tags { + if propData, err := t.Parse(); err == nil { + results = append(results, &types.EntityTag{ + ID: strconv.FormatUint(t.ID, 10), + CreatedAt: t.CreatedAt.In(time.UTC).Local(), + LastSeen: t.UpdatedAt.In(time.UTC).Local(), + Property: propData, + }) + } + } + + if len(results) == 0 { + return nil, errors.New("zero entity tags found") + } + return results, nil +} + // GetEntityTags finds all tags for the entity with the specified names and last seen after the since parameter. // If since.IsZero(), the parameter will be ignored. // If no names are specified, all tags for the specified entity are returned. @@ -307,6 +362,61 @@ func (sql *sqlRepository) FindEdgeTagById(id string) (*types.EdgeTag, error) { }, nil } +// FindEdgeTagsByContent finds edge tags in the database that match the provided property data and updated_at after the since parameter. +// It takes an oam.Property as input and searches for edge tags with matching content in the database. +// If since.IsZero(), the parameter will be ignored. +// The property data is serialized to JSON and compared against the Content field of the EdgeTag struct. +// Returns a slice of matching edge tags as []*types.EdgeTag or an error if the search fails. +func (sql *sqlRepository) FindEdgeTagsByContent(prop oam.Property, since time.Time) ([]*types.EdgeTag, error) { + jsonContent, err := prop.JSON() + if err != nil { + return nil, err + } + + tag := EdgeTag{ + Type: string(prop.PropertyType()), + Content: jsonContent, + } + + nameQuery, err := tag.NameJSONQuery() + if err != nil { + return nil, err + } + + valueQuery, err := tag.ValueJSONQuery() + if err != nil { + return nil, err + } + + tx := sql.db.Where("ttype = ?", tag.Type) + if !since.IsZero() { + tx = tx.Where("updated_at >= ?", since.UTC()) + } + + var tags []EdgeTag + tx = tx.Where(nameQuery).Where(valueQuery).Find(&tags) + if err := tx.Error; err != nil { + return nil, err + } + + var results []*types.EdgeTag + for _, t := range tags { + if propData, err := t.Parse(); err == nil { + results = append(results, &types.EdgeTag{ + ID: strconv.FormatUint(t.ID, 10), + CreatedAt: t.CreatedAt.In(time.UTC).Local(), + LastSeen: t.UpdatedAt.In(time.UTC).Local(), + Property: propData, + }) + } + } + + if len(results) == 0 { + return nil, errors.New("zero edge tags found") + } + return results, nil +} + // GetEdgeTags finds all tags for the edge with the specified names and last seen after the since parameter. // If since.IsZero(), the parameter will be ignored. // If no names are specified, all tags for the specified edge are returned.