Skip to content

Commit

Permalink
Add unit tests and a new raw_update convenience method #102
Browse files Browse the repository at this point in the history
Signed-off-by: tdruez <[email protected]>
  • Loading branch information
tdruez committed Dec 27, 2024
1 parent 551954f commit 4532bc2
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 13 deletions.
29 changes: 26 additions & 3 deletions dje/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,15 +863,38 @@ def update_from_data(self, user, data, override=False, override_unknown=False):

def update(self, **kwargs):
"""
Update this instance with the provided ``kwargs`` values.
The full ``save()`` process will be triggered, including signals, and the
``update_fields`` is automatically set.
Update this instance with the provided field values.
This method modifies the specified fields on the current instance and triggers
the full ``save()`` lifecycle, including calling signals like ``pre_save`` and
``post_save``.
The ``update_fields`` parameter is automatically set to limit the save
operation to the updated fields.
"""
for field_name, value in kwargs.items():
setattr(self, field_name, value)

self.save(update_fields=list(kwargs.keys()))

def raw_update(self, **kwargs):
"""
Perform a direct SQL UPDATE on this instance.
This method updates the specified fields in the database without triggering
the ``save()`` lifecycle or related signals. It bypasses field validation and
other ORM hooks for improved performance, but requires careful usage to avoid
inconsistent states.
The instance's in-memory attributes are updated to reflect the changes.
"""
updated_rows = self.__class__.objects.filter(pk=self.pk).update(**kwargs)

# Update the instance's attributes in memory
for field_name, value in kwargs.items():
setattr(self, field_name, value)

return updated_rows

def as_json(self):
try:
serialized_data = serialize(
Expand Down
10 changes: 10 additions & 0 deletions product_portfolio/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dje.tests import make_string
from product_portfolio.models import Product
from product_portfolio.models import ProductComponent
from product_portfolio.models import ProductItemPurpose
from product_portfolio.models import ProductPackage


Expand Down Expand Up @@ -64,3 +65,12 @@ def make_product_component(product, component=None):
component=component,
dataspace=dataspace,
)


def make_product_item_purpose(dataspace, **data):
return ProductItemPurpose.objects.create(
label=make_string(10),
text=make_string(10),
dataspace=dataspace,
**data,
)
36 changes: 26 additions & 10 deletions product_portfolio/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@
from product_portfolio.models import ProductComponentAssignedLicense
from product_portfolio.models import ProductDependency
from product_portfolio.models import ProductInventoryItem
from product_portfolio.models import ProductItemPurpose
from product_portfolio.models import ProductPackage
from product_portfolio.models import ProductRelationStatus
from product_portfolio.models import ProductSecuredManager
from product_portfolio.models import ProductStatus
from product_portfolio.models import ScanCodeProject
from product_portfolio.tests import make_product_item_purpose
from product_portfolio.tests import make_product_package
from vulnerabilities.tests import make_vulnerability
from workflow.models import RequestTemplate
Expand Down Expand Up @@ -740,12 +740,28 @@ def test_productcomponent_model_is_custom_component(self):
pc1.save()
self.assertFalse(pc1.is_custom_component)

def test_product_relationship_queryset_vulnerable(self):
pp1 = make_product_package(self.product1)
product_package_qs = ProductPackage.objects.vulnerable()
self.assertEqual(0, product_package_qs.count())

pp1.raw_update(weighted_risk_score=5.0)
product_package_qs = ProductPackage.objects.vulnerable()
self.assertEqual(1, product_package_qs.count())
self.assertIn(pp1, product_package_qs)

def test_product_relationship_queryset_annotate_weighted_risk_score(self):
purpose1 = make_product_item_purpose(self.dataspace, exposure_factor=0.5)
package1 = make_package(self.dataspace, risk_score=4.0)
make_product_package(self.product1, package=package1, purpose=purpose1)

product_package_qs = ProductPackage.objects.annotate_weighted_risk_score()
self.assertEqual(0.5, product_package_qs[0].exposure_factor)
self.assertEqual(4.0, product_package_qs[0].risk_score)
self.assertEqual(2.0, product_package_qs[0].computed_weighted_risk_score)

def test_product_relationship_queryset_update_weighted_risk_score(self):
purpose1 = ProductItemPurpose.objects.create(
label="Core",
text="Text",
dataspace=self.dataspace,
)
purpose1 = make_product_item_purpose(self.dataspace)

# 1. package.risk_score = None, purpose = None
package1 = make_package(self.dataspace)
Expand All @@ -756,22 +772,22 @@ def test_product_relationship_queryset_update_weighted_risk_score(self):
self.assertIsNone(pp1.weighted_risk_score)

# 2. package.risk_score = 4.0, purpose.exposure_factor = None
Package.objects.filter(pk=package1.pk).update(risk_score=4.0)
ProductPackage.objects.filter(pk=pp1.pk).update(purpose=purpose1)
package1.raw_update(risk_score=4.0)
pp1.raw_update(purpose=purpose1)
updated_count = ProductPackage.objects.update_weighted_risk_score()
self.assertEqual(1, updated_count)
pp1.refresh_from_db()
self.assertEqual(4.0, pp1.weighted_risk_score)

# 3. package.risk_score = 4.0, purpose.exposure_factor = 0.5
ProductItemPurpose.objects.filter(pk=purpose1.pk).update(exposure_factor=0.5)
purpose1.raw_update(exposure_factor=0.5)
updated_count = ProductPackage.objects.update_weighted_risk_score()
self.assertEqual(1, updated_count)
pp1.refresh_from_db()
self.assertEqual(2.0, pp1.weighted_risk_score)

# 4. package.risk_score = None, purpose.exposure_factor = 0.5
Package.objects.filter(pk=package1.pk).update(risk_score=None)
package1.raw_update(risk_score=None)
updated_count = ProductPackage.objects.update_weighted_risk_score()
self.assertEqual(1, updated_count)
pp1.refresh_from_db()
Expand Down

0 comments on commit 4532bc2

Please sign in to comment.