From fca35a1151e16c9c6fc9a4c4f869af3f8c1d1eff Mon Sep 17 00:00:00 2001 From: Stuart Axon Date: Thu, 23 Jun 2022 18:15:10 +0100 Subject: [PATCH] This PR builds on the work in the initial PR to move business rules to celery along with info learned deploying this. Avoid filling the task queue with orchestration tasks and starving the workers. =============================================================================== In the previous system there were about 3 layers of tasks, that orchestrated other tasks, by using the .replace() API in each task. Unfortunately it was possible for celery workers to become full of orchestration tasks leaving no room for the business rule tasks at the bottom of the to actually run. This PR attempts two mitigations: 1. Use celery workflows instead of .replace() This PR builds a celery workflow in the check_workbasket using celery constructs such as chain and group. In theory, since most of the work is done ahead of time the system should have more awareness of the task structure avoiding the issue of starvation. 2. Cancel existing workbasket checks when a new check is requested. When check_workbasket is started, it will attempt to revoke existing check_workbasket tasks for the same workbasket. Treat intermediate data structures as ephemeral =============================================== A celery task may execute at any time, right now - or when a system comes up tomorrow, based on this assumption models such as TrackedModelCheck (which stores the result of a business rule check on a TrackedModel) are no longer passed to celery tasks by ID, instead all the information needed to receate the data is passed to the celery task, this means the system will still work even if developers delete these while it is running. Reduce layers in business rule checking ======================================= BusinessRuleChecker and LinkedModelsBusinessRuleChecker are now the only checkers, these now take BusinessRule instances, instead of being subclassed for each business rule. While more parameters are passed when rules are checked a conceptual layer has been removed and the simplification is reflected with around 20 lines of code being removed from checks.py Celery flower is now very easier to read ======================================== Due to the changes above, the output in celery flower should correspond more closely to a users intentions - ids of models. Content Checksums ================= Result caching now validates using checksums of the content, which should reduce the amount of checking the system needs to do. When a workbasket has been published, it's content could invalidate some content in other unpublished workbaskets, by associating business rule checks with checksums of a models content, any models that do not clash can be skipped. Model checksums (generated by `.content_hash()`) are not currently stored in the database (though it may be desirable to store them on TrackedModels, as it would provide an mechanism to address any content in the system). The checksuming scheme is a combination of the type and a sha256 of the fields in `.copyable_fields` (which should represent the fields a user can edit, but not fields such as pk). Blake3 was tested, as it provides a fast hashing algorithm, in practice it didn't provide much of a speedup over sha256. PK ranges ========= Occasionally workbaskets with many items may need to be checker (the initial workbasket has 9 million items). Based on the observations that the ID column of the contained TrackedModels is mostly continguous, the system allows passing sequences of contiguous TrackedModels specified by tuples of (first_pk, last_pk). This is relatively compact, suitable for passing over the network with celery and readable in Celery flower. This also enables chunking of tasks - further enabled by specifying a maximum amount of items in each tuple. On TrackedModelQueryset `.as_pk_intervals` and `.from_pk_intervals` are provided to go to and from this format. --- checks/checks.py | 341 +++++++-------- checks/migrations/0004_auto_20220718_1653.py | 31 ++ checks/migrations/0005_trackedmodelcheck.py | 53 +++ checks/models.py | 143 ++----- checks/querysets.py | 187 ++------- checks/tasks.py | 387 +++++++++++------- checks/tests/factories.py | 140 +++---- checks/tests/test_checkers.py | 172 +++++--- checks/tests/test_tasks.py | 259 ++++++------ checks/tests/util.py | 2 +- commodities/tests/test_business_rules.py | 2 +- .../0006_modelcelerytask_taskmodel.py | 92 +++++ common/models/__init__.py | 4 + common/models/celerytask.py | 194 +++++++++ common/models/tracked_qs.py | 144 ++++++- common/models/tracked_utils.py | 37 ++ common/models/trackedmodel.py | 21 +- common/models/transactions.py | 2 +- common/models/utils.py | 3 +- common/tests/test_business_rules.py | 5 - common/tests/test_models.py | 2 +- common/tests/util.py | 23 +- .../management/commands/dump_transactions.py | 6 +- footnotes/tests/test_views.py | 2 +- pii-ner-exclude.txt | 65 +++ quotas/models.py | 3 + settings/common.py | 31 +- .../management/commands/list_workbaskets.py | 3 +- workbaskets/management/commands/run_checks.py | 208 ++++++++++ .../management/commands/sync_run_checks.py | 2 +- workbaskets/management/util.py | 15 +- workbaskets/models.py | 37 +- workbaskets/tasks.py | 37 +- workbaskets/tests/test_models.py | 6 +- workbaskets/tests/util.py | 2 +- 35 files changed, 1703 insertions(+), 958 deletions(-) create mode 100644 checks/migrations/0004_auto_20220718_1653.py create mode 100644 checks/migrations/0005_trackedmodelcheck.py create mode 100644 common/migrations/0006_modelcelerytask_taskmodel.py create mode 100644 common/models/celerytask.py create mode 100644 workbaskets/management/commands/run_checks.py diff --git a/checks/checks.py b/checks/checks.py index bd715132a6..64c7eb0dfb 100644 --- a/checks/checks.py +++ b/checks/checks.py @@ -1,226 +1,241 @@ -from functools import cached_property -from typing import Collection -from typing import Dict -from typing import Iterator +import logging +from collections import defaultdict from typing import Optional +from typing import Set from typing import Tuple -from typing import Type -from typing import TypeVar + +from django.conf import settings from checks.models import TrackedModelCheck -from checks.models import TransactionCheck -from common.business_rules import ALL_RULES from common.business_rules import BusinessRule from common.business_rules import BusinessRuleViolation -from common.models.trackedmodel import TrackedModel +from common.models import TrackedModel +from common.models import Transaction from common.models.utils import get_current_transaction from common.models.utils import override_current_transaction -CheckResult = Tuple[bool, Optional[str]] - +logger = logging.getLogger(__name__) -Self = TypeVar("Self") +CheckResult = Tuple[bool, Optional[str]] class Checker: - """ - A ``Checker`` is an object that knows how to perform a certain kind of check - against a model. - - Checkers can be applied against a model. The logic of the checker will be - run and the result recorded as a ``TrackedModelCheck``. - """ - - @cached_property - def name(self) -> str: - """ - The name string that on a per-model basis uniquely identifies the - checker. - - The name should be deterministic (i.e. not rely on the current - environment, memory locations or random data) so that the system can - record the name in the database and later use it to work out whether - this check has been run. The name doesn't need to include any details - about the model. - - By default this is the name of the class, but it can include any other - non-model data that is unique to the checker. For a more complex - example, see ``IndirectBusinessRuleChecker.name``. + @classmethod + def run_rule( + cls, + rule: BusinessRule, + transaction: Transaction, + model: TrackedModel, + ) -> CheckResult: """ - return type(self).__name__ + Run a single business rule on a single model. - @classmethod - def checkers_for(cls: Type[Self], model: TrackedModel) -> Collection[Self]: + :return CheckResult, a Tuple(rule_passed: str, violation_reason: Optional[str]). """ - Returns instances of this ``Checker`` that should apply to the model. + logger.debug(f"run_rule %s %s %s", model, rule, transaction.pk) + try: + rule(transaction).validate(model) + logger.debug(f"%s [tx:%s] %s [passed]", model, rule, transaction.pk) + return True, None + except BusinessRuleViolation as violation: + reason = violation.args[0] + logger.debug(f"%s [tx:%s] %s [failed]", model, rule, transaction.pk, reason) + return False, reason - What checks apply to a model is sometimes data-dependent, so it is the - responsibility of the ``Checker`` class to tell the system what - instances of itself it would expect to run against the model. For an - example, see ``IndirectBusinessRuleChecker.checkers_for``. + @classmethod + def apply_rule( + cls, + rule: BusinessRule, + transaction: Transaction, + model: TrackedModel, + ): """ - raise NotImplementedError() + Applies the rule to the model and records success in a + TrackedModelCheck. - def run(self, model: TrackedModel) -> CheckResult: - """Runs Checker-dependent logic and returns an indication of success.""" - raise NotImplementedError() + If a TrackedModelCheck already exists with a matching content checksum it + will be updated, otherwise a new one will be created. - def apply(self, model: TrackedModel, context: TransactionCheck): - """Applies the check to the model and records success.""" + :return: TrackedModelCheck instance containing the result of the check. + During debugging the developer can set settings.RAISE_BUSINESS_RULE_FAILURES + to True to raise business rule violations as exceptions. + """ success, message = False, None try: - with override_current_transaction(context.transaction): - success, message = self.run(model) + with override_current_transaction(transaction): + success, message = cls.run_rule(rule, transaction, model) except Exception as e: success, message = False, str(e) + if settings.RAISE_BUSINESS_RULE_FAILURES: + # RAISE_BUSINESS_RULE_FAILURES can be set by the developer to raise + # Exceptions. + raise finally: - return TrackedModelCheck.objects.create( + check, created = TrackedModelCheck.objects.get_or_create( + { + "successful": success, + "message": message, + "content_hash": model.content_hash().digest(), + }, model=model, - transaction_check=context, - check_name=self.name, - successful=success, - message=message, + check_name=rule.__name__, ) + if not created: + check.successful = success + check.message = message + check.content_hash = model.content_hash().digest() + check.save() + return check + @classmethod + def apply_rule_cached( + cls, + rule: BusinessRule, + transaction: Transaction, + model: TrackedModel, + ): + """ + If a matching TrackedModelCheck instance exists, returns it, otherwise + check rule, and return the result as a TrackedModelCheck instance. -class BusinessRuleChecker(Checker): - """ - A ``Checker`` that runs a ``BusinessRule`` against a model. - - This class is expected to be sub-typed for a specific rule by a call to - ``of()``. + :return: TrackedModelCheck instance containing the result of the check. + """ + try: + check = TrackedModelCheck.objects.get( + model=model, + check_name=rule.__name__, + ) + except TrackedModelCheck.DoesNotExist: + logger.debug( + "apply_rule_cached (no existing check) %s, %s apply rule", + rule.__name__, + transaction, + ) + return cls.apply_rule(rule, transaction, model) + + # Re-run the rule if the content checksum no longer matches that of the previous test. + check_hash = bytes(check.content_hash) + model_hash = model.content_hash().digest() + if check_hash == model_hash: + logger.debug( + "apply_rule_cached (matching content hash) %s, tx: %s, using cached result %s", + rule.__name__, + transaction.pk, + check, + ) + return check - Attributes: - checker_cache (dict): (class attribute) Cache of Business checkers created by ``of()``. - """ + logger.debug( + "apply_rule_cached (check.content_hash != model.content_hash()) %s != %s %s, %s apply rule", + check_hash, + model_hash, + rule.__name__, + transaction, + ) + check.delete() + return cls.apply_rule(rule, transaction, model) - rule: Type[BusinessRule] - _checker_cache: Dict[str, BusinessRule] = {} +class BusinessRuleChecker(Checker): + """Apply BusinessRules specified in a TrackedModels business_rules + attribute.""" @classmethod - def of(cls: Type, rule_type: Type[BusinessRule]) -> Type: + def apply_rule( + cls, + rule: BusinessRule, + transaction: Transaction, + model: TrackedModel, + ): """ - Return a subclass of a Checker, e.g. BusinessRuleChecker, - IndirectBusinessRuleChecker that runs the passed in business rule. - - Example, creating a BusinessRuleChecker for ME32: + Run the current business rule on the model. - >>> BusinessRuleChecker.of(measures.business_rules.ME32) - + :return: TrackedModelCheck instance containing the result of the check. + :raises: ValueError if the rule is not in the model's business_rules attribute - This API is usually called by .applicable_to, however this docstring should - illustrate what it does. + To get a list of applicable rules, get_model_rules can be used. + """ + if rule not in model.business_rules: + raise ValueError( + f"{model} does not have {rule} in its business_rules attribute.", + ) - Checkers are created once and then cached in _checker_cache. + return super().apply_rule(rule, transaction, model) - As well as a small performance improvement, caching aids debugging by ensuring - the same checker instance is returned if the same cls is passed to ``of``. + @classmethod + def get_model_rules(cls, model: TrackedModel, rules: Optional[Set[str]] = None): """ - checker_name = f"{cls.__name__}Of[{rule_type.__module__}.{rule_type.__name__}]" - - # If the checker class was already created, return it. - checker_class = cls._checker_cache.get(checker_name) - if checker_class is not None: - return checker_class - # No existing checker was found, so create it: + :param model: TrackedModel instance + :param rules: Optional list of rule names to filter by. + :return: Dict mapping models to a set of the BusinessRules that apply to them. + """ + model_rules = defaultdict(set) - class BusinessRuleCheckerOf(cls): - # Creating this class explicitly in code is more readable than using type(...) - # Once created the name will be mangled to include the rule to be checked. + for rule in model.business_rules: + if rules is not None and rule.__name__ not in rules: + continue - f"""Apply the following checks as specified in {rule_type.__name__}""" - rule = rule_type + model_rules[model].add(rule) - def __repr__(self): - return f"<{checker_name}>" + # Downcast to a dict - this API (and unit testing) a little more sane. + return {**model_rules} - BusinessRuleCheckerOf.__name__ = checker_name - cls._checker_cache[checker_name] = BusinessRuleCheckerOf - return BusinessRuleCheckerOf +class LinkedModelsBusinessRuleChecker(Checker): + """Apply BusinessRules specified in a TrackedModels indirect_business_rules + attribute to models returned by get_linked_models on those rules.""" @classmethod - def checkers_for(cls: Type[Self], model: TrackedModel) -> Collection[Self]: - """If the rule attribute on this BusinessRuleChecker matches any in the - supplied TrackedModel instance's business_rules, return it in a list, - otherwise there are no matches so return an empty list.""" - if cls.rule in model.business_rules: - return [cls()] - return [] - - def run(self, model: TrackedModel) -> CheckResult: + def apply_rule( + cls, + rule: BusinessRule, + transaction: Transaction, + model: TrackedModel, + ): """ - :return CheckResult, a Tuple(rule_passed: str, violation_reason: Optional[str]). - """ - transaction = get_current_transaction() - try: - self.rule(transaction).validate(model) - return True, None - except BusinessRuleViolation as violation: - return False, violation.args[0] - + LinkedModelsBusinessRuleChecker assumes that the linked models are + still. -class IndirectBusinessRuleChecker(BusinessRuleChecker): - """ - A ``Checker`` that runs a ``BusinessRule`` against a model that is linked to - the model being checked, and for which a change in the checked model could - result in a business rule failure against the linked model. + the current versions (TODO - ensure a business rule checks this), - This is a base class: subclasses for checking specific rules are created by - calling ``of()``. - """ + :return: TrackedModelCheck instance containing the result of the check. + :raises: ValueError if the rule is not in the model's indirect_business_rules attribute - rule: Type[BusinessRule] - linked_model: TrackedModel - - def __init__(self, linked_model: TrackedModel) -> None: - self.linked_model = linked_model - super().__init__() + get_model_rules should be called to get a list of applicable rules and them models they apply to. + """ + if rule not in model.indirect_business_rules: + raise ValueError( + f"{model} does not have {rule} in its indirect_business_rules attribute.", + ) - @cached_property - def name(self) -> str: - # Include the identity of the linked model in the checker name, so that - # each linked model needs to be checked for all checks to be complete. - return f"{super().name}[{self.linked_model.pk}]" + return super().apply_rule(rule, model.transaction, model) @classmethod - def checkers_for(cls: Type[Self], model: TrackedModel) -> Collection[Self]: - """Return a set of IndirectBusinessRuleCheckers for every model found on - rule.get_linked_models.""" - rules = set() - transaction = get_current_transaction() - if cls.rule in model.indirect_business_rules: - for linked_model in cls.rule.get_linked_models(model, transaction): - rules.add(cls(linked_model)) - return rules - - def run(self, model: TrackedModel) -> CheckResult: + def get_model_rules(cls, model: TrackedModel, rules: Optional[Set] = None): """ - Return the result of running super.run, passing self.linked_model, and. - - return it as a CheckResult - a Tuple(rule_passed: str, violation_reason: Optional[str]) + :param model: Initial TrackedModel instance + :param rules: Optional list of rule names to filter by. + :return: Dict mapping linked models with sets of the BusinessRules that apply to them. """ - result, message = super().run(self.linked_model) - message = f"{self.linked_model}: " + message if message else None - return result, message + tx = get_current_transaction() + + model_rules = defaultdict(set) + for rule in [*model.indirect_business_rules]: + for linked_model in rule.get_linked_models(model, tx): + if rules is not None and rule.__name__ not in rules: + continue -def checker_types() -> Iterator[Type[Checker]]: - """ - Return all registered Checker types. + model_rules[linked_model].add(rule) - See ``checks.checks.BusinessRuleChecker.of``. - """ - for rule in ALL_RULES: - yield BusinessRuleChecker.of(rule) - yield IndirectBusinessRuleChecker.of(rule) + # Downcast to a dict - this API (and unit testing) a little more sane. + return {**model_rules} -def applicable_to(model: TrackedModel) -> Iterator[Checker]: - """Return instances of any Checker classes applicable to the supplied - TrackedModel instance.""" - for checker_type in checker_types(): - yield from checker_type.checkers_for(model) +# Checkers in priority list order, checkers for linked models come first. +ALL_CHECKERS = { + "LinkedModelsBusinessRuleChecker": LinkedModelsBusinessRuleChecker, + "BusinessRuleChecker": BusinessRuleChecker, +} diff --git a/checks/migrations/0004_auto_20220718_1653.py b/checks/migrations/0004_auto_20220718_1653.py new file mode 100644 index 0000000000..76e3157e3c --- /dev/null +++ b/checks/migrations/0004_auto_20220718_1653.py @@ -0,0 +1,31 @@ +# Generated by Django 3.1.14 on 2022-07-18 16:53 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("checks", "0003_auto_20220525_1046"), + ] + + operations = [ + migrations.RemoveField( + model_name="transactioncheck", + name="head_transaction", + ), + migrations.RemoveField( + model_name="transactioncheck", + name="latest_tracked_model", + ), + migrations.RemoveField( + model_name="transactioncheck", + name="transaction", + ), + migrations.DeleteModel( + name="TrackedModelCheck", + ), + migrations.DeleteModel( + name="TransactionCheck", + ), + ] diff --git a/checks/migrations/0005_trackedmodelcheck.py b/checks/migrations/0005_trackedmodelcheck.py new file mode 100644 index 0000000000..f471a2d07a --- /dev/null +++ b/checks/migrations/0005_trackedmodelcheck.py @@ -0,0 +1,53 @@ +# Generated by Django 3.1.14 on 2022-08-02 20:32 + +import django.db.models.deletion +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("common", "0006_modelcelerytask_taskmodel"), + ("checks", "0004_auto_20220718_1653"), + ] + + operations = [ + migrations.CreateModel( + name="TrackedModelCheck", + fields=[ + ( + "taskmodel_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="common.taskmodel", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("check_name", models.CharField(max_length=255)), + ("successful", models.BooleanField()), + ("message", models.TextField(null=True)), + ("content_hash", models.BinaryField(max_length=32, null=True)), + ( + "model", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="checks", + to="common.trackedmodel", + ), + ), + ], + options={ + "unique_together": {("model", "check_name")}, + }, + bases=("common.taskmodel", models.Model), + ), + ] diff --git a/checks/models.py b/checks/models.py index f9680fa399..971e7648e8 100644 --- a/checks/models.py +++ b/checks/models.py @@ -1,136 +1,34 @@ +import logging + from django.db import models from django.db.models import fields +from polymorphic.managers import PolymorphicManager -from checks.querysets import TransactionCheckQueryset +from checks.querysets import TrackedModelCheckQueryset +from common.models import TimestampedMixin +from common.models.celerytask import TaskModel from common.models.trackedmodel import TrackedModel -from common.models.transactions import Transaction - - -class TransactionCheck(models.Model): - """ - Represents an in-progress or completed check of a transaction for - correctness. - - The ``TransactionCheck`` gets created once the check starts and has a flag - to track completeness. - """ - - transaction = models.ForeignKey( - Transaction, - on_delete=models.CASCADE, - related_name="checks", - ) - - completed = fields.BooleanField(default=False) - """True if all of the checks expected to be carried out against the models - in this transaction have recorded any result.""" - - successful = fields.BooleanField(null=True) - """ - True if all of the checks carried out against the models in this - transaction returned a positive result. - - This value will be null until ``completed`` is `True`. - """ - head_transaction_id: int - head_transaction = models.ForeignKey( - Transaction, - on_delete=models.CASCADE, - ) - """ - The latest transaction in the stream of approved transactions (i.e. in the - REVISION partition) at the moment this check was carried out. +logger = logging.getLogger(__name__) - Once new transactions are commited and the head transaction is no longer the - latest, this check will no longer be an accurate signal of correctness - because the new transactions could include new data which would invalidate - the checks. (Unless the checked transaction < head transaction, in which - case it will always be correct.) - """ - tracked_model_count = fields.PositiveSmallIntegerField() +class TrackedModelCheck(TimestampedMixin, TaskModel): """ - The number of tracked models in the transaction at the moment this check was - carried out. + Represents the result of running a single check against a single model. - If something is removed from the transaction later, the number of tracked - models will no longer match. This is used to detect if the check is now - stale. + Stores `content_hash`, a hash of the content for validity checking of the + stored result. """ - latest_tracked_model = models.ForeignKey( - TrackedModel, - on_delete=models.CASCADE, - null=True, - ) - """ - The latest tracked model in the transaction at the moment this check was - carried out. - - If some models are removed and subsequent ones added to the transaction, the - count may be the same but the latest transaction will have a new primary - key. This is used to detect if the check is now stale. - """ - - model_checks: models.QuerySet["TrackedModelCheck"] - - objects: TransactionCheckQueryset = models.Manager.from_queryset( - TransactionCheckQueryset, - )() - - def save(self, *args, **kwargs): - """Computes the metadata we will need later to detect if the check is - current and fresh.""" - if not self.head_transaction_id: - self.head_transaction = Transaction.approved.last() - - self.tracked_model_count = self.transaction.tracked_models.count() - self.latest_tracked_model = self.transaction.tracked_models.order_by( - "pk", - ).last() - - return super().save(*args, **kwargs) - class Meta: - ordering = ( - "transaction__partition", - "transaction__order", - "head_transaction__partition", - "head_transaction__order", - ) - - constraints = ( - models.CheckConstraint( - check=( - models.Q(completed=False, successful__isnull=True) - | models.Q(completed=True, successful__isnull=False) - ), - name="completed_checks_include_successfulness", - ), - ) - - -class TrackedModelCheck(models.Model): - """ - Represents the result of running a single check against a single model. - - The ``TrackedModelCheck`` only gets created once the check is complete, and - hence success should always be known. The reason is that a single model - check is atomic (i.e. there is no smaller structure) and so it's either done - or not, and it can't be "resumed". - """ + unique_together = ("model", "check_name") + objects = PolymorphicManager.from_queryset(TrackedModelCheckQueryset)() model = models.ForeignKey( TrackedModel, related_name="checks", - on_delete=models.CASCADE, - ) - - transaction_check = models.ForeignKey( - TransactionCheck, - on_delete=models.CASCADE, - related_name="model_checks", + on_delete=models.SET_NULL, + null=True, ) check_name = fields.CharField(max_length=255) @@ -141,3 +39,14 @@ class TrackedModelCheck(models.Model): message = fields.TextField(null=True) """The text content returned by the check, if any.""" + + content_hash = models.BinaryField(max_length=32, null=True) + """ + Hash of the content ('copyable_fields') at the time the data was checked. + """ + + def __str__(self): + if self.successful: + return f"{self.model} {self.check_name} [Passed at {self.updated_at}]" + + return f"{self.model} {self.check_name} [Failed at {self.updated_at}, Message: {self.message}]" diff --git a/checks/querysets.py b/checks/querysets.py index 63a132e891..a429e4c1e3 100644 --- a/checks/querysets.py +++ b/checks/querysets.py @@ -1,173 +1,32 @@ -from django.contrib.postgres.aggregates import BoolOr -from django.db import models -from django.db.models import expressions -from django.db.models.aggregates import Count -from django.db.models.aggregates import Max -from django_cte import CTEQuerySet -from django_cte import With +from django.db.transaction import atomic +from polymorphic.query import PolymorphicQuerySet -from common.models.transactions import Transaction -from common.models.transactions import TransactionPartition -from common.models.utils import LazyTransaction -latest_transaction = LazyTransaction(get_value=Transaction.approved.last) - - -class TransactionCheckQueryset(CTEQuerySet): - currentness_filter = ( - # If the head transaction is ahead of the latest transaction then no new - # transactions have been committed since the check. In practice we only - # expect the head_transaction == latest_transaction but it doesn't hurt - # to be more defensive with a greater than check. - # - # head_transaction >= latest_transaction - models.Q( - head_transaction__partition__gt=latest_transaction.partition, - ) - | models.Q( - head_transaction__partition=latest_transaction.partition, - head_transaction__order__gte=latest_transaction.order, - ) - # If the head transaction was ahead of the checked transaction when the - # check was carried out, and the checked transaction is approved, we - # don't need to update anything because subsequent changes can't affect - # the older transaction. - # - # head_transaction >= checked_transaction AND checked_transaction is - # approved - | models.Q( - head_transaction__partition__gt=models.F("transaction__partition"), - transaction__partition__in=TransactionPartition.approved_partitions(), - ) - | models.Q( - head_transaction__partition=models.F("transaction__partition"), - head_transaction__order__gte=models.F("transaction__order"), - transaction__partition__in=TransactionPartition.approved_partitions(), - ) - ) - - freshness_fields = { - # See the field descriptions on ``TransactionCheck`` for details on - # how these fields are populated and used to calculate freshness. - "tracked_model_count": Count("transaction__tracked_models"), - "latest_tracked_model": Max("transaction__tracked_models__id"), - } - - freshness_annotations = { - f"real_{field}": expr for field, expr in freshness_fields.items() - } - - freshness_filter = ( - # Use the metadata on the transaction check to work out if the check - # still represents the data in the transaction. The "real_" fields are - # expected to be annotated onto the queryset and represent the current - # state of the transaction. - # - # A fresh check is where all the current values match the stored values. - models.Q(**{field: models.F(f"real_{field}") for field in freshness_fields}) - # ...or where there are no models to check, which is valid. - | models.Q( - real_latest_tracked_model__isnull=True, - latest_tracked_model__isnull=True, - ) - ) - - requires_update_filter = (~freshness_filter) | (~currentness_filter) - - requires_update_annotation = expressions.ExpressionWrapper( - expression=requires_update_filter, - output_field=models.fields.BooleanField(), - ) - - def current(self): +class TrackedModelCheckQueryset(PolymorphicQuerySet): + def delete(self): """ - A ``TransactionCheck`` is considered "current" if there hasn't been any - data added after the check that could change the result of the check. + Delete, modified to workaround a python bug that stops delete from + working when some fields are ByteFields. - If the checked transaction is in a draft partition, "current" means no - new transactions have been approved since the check was carried out. If - any have, they will now potentially be in scope of the check. + Details: - If the checked transaction is in an approved partition, "current" means - no transactions were approved between the check happening and the - transaction being committed to the approved partition (but some may have - been added after it, which can't affect its result). - """ - return self.filter(self.currentness_filter) + Using .delete() on a query with ByteFields does not work due to a python bug: + https://github.com/python/cpython/issues/95081 + >>> TrackedModelCheck.objects.filter( + model__transaction__workbasket=workbasket_pk, + ).delete() - def fresh(self): - """ - A ``TransactionCheck`` is considered "fresh" if the transaction that it - checked hasn't been modified since the check was carried out, which - could change the result of the check. + File /usr/local/lib/python3.8/copy.py:161, in deepcopy(x, memo, _nil) + 159 reductor = getattr(x, "__reduce_ex__", None) + 160 if reductor is not None: + --> 161 rv = reductor(4) + 162 else: + 163 reductor = getattr(x, "__reduce__", None) - The ``tracked_model_count`` and ``latest_tracked_model`` of the checked - transaction are cached on the check and used to detect this. - """ - return self.annotate(**self.freshness_annotations).filter(self.freshness_filter) - - def stale(self): - """A ``TransactionCheck`` is considered "stale" if the transaction that - it checked has been modified since the check was carried out, which - could change the result of the check.""" - return self.annotate(**self.freshness_annotations).exclude( - self.freshness_filter, - ) - - def requires_update(self, requirement=True, include_archived=False): - """ - A ``TransactionCheck`` requires an update if it or any check on a - transaction before it in order is stale or no longer current. - - If a ``TransactionCheck`` on an earlier transaction is stale, it means - that transaction has been modified since the check was done, which could - also invalidate any checks of any subsequent transactions. + TypeError: cannot pickle 'memoryview' object - By default transactions in ARCHIVED workbaskets are ignored, since these - workbaskets exist outside of the normal workflow. + Work around this by setting the bytefields to None and then calling delete. """ - - if include_archived: - ignore_filter = {} - else: - ignore_filter = {"transaction__workbasket__status": "ARCHIVED"} - - # First filtering out any objects we should ignore, - # work out for each check whether it alone requires an update, by - # seeing whether it is stale or not current. - basic_info = With( - self.model.objects.exclude(**ignore_filter) - .annotate(**self.freshness_annotations) - .annotate( - requires_update=self.requires_update_annotation, - ), - name="basic_info", - ) - - # Now cascade that result down to any subsequent transactions: if a - # transaction in the same workbasket comes later, then it will also - # require an update. TODO: do stale transactions pollute the update - # check for ever? - sequence_info = With( - basic_info.join(self.model.objects.all(), pk=basic_info.col.pk).annotate( - requires_update=expressions.Window( - expression=BoolOr(basic_info.col.requires_update), - partition_by=models.F("transaction__workbasket"), - order_by=[ - models.F("transaction__order").asc(), - models.F("pk").desc(), - ], - ), - ), - name="sequence_info", - ) - - # Now filter for only the type that we want: checks that either do or do - # not require an update. - return ( - sequence_info.join(self, pk=sequence_info.col.pk) - .with_cte(basic_info) - .with_cte(sequence_info) - .annotate(requires_update=sequence_info.col.requires_update) - .filter(requires_update=requirement) - ) + with atomic(): + self.update(content_hash=None) + return super().delete() diff --git a/checks/tasks.py b/checks/tasks.py index dd7d0981c1..cfc9340405 100644 --- a/checks/tasks.py +++ b/checks/tasks.py @@ -1,172 +1,269 @@ -from itertools import cycle - -from celery import group +""" +Celery tasks and workflow. + +Build a workflow of tasks in one go and to pass to celery. +""" +import logging +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Tuple + +from celery import chain +from celery import chord from celery.utils.log import get_task_logger -from checks.checks import applicable_to -from checks.models import TransactionCheck +from checks.checks import ALL_CHECKERS +from checks.checks import Checker +from checks.models import TrackedModelCheck +from common.business_rules import ALL_RULES +from common.business_rules import BusinessRule from common.celery import app +from common.models.celerytask import ModelCeleryTask +from common.models.celerytask import bind_model_task from common.models.trackedmodel import TrackedModel from common.models.transactions import Transaction -from common.models.transactions import TransactionPartition -from common.models.utils import override_current_transaction +from common.models.utils import get_current_transaction +from workbaskets.models import WorkBasket # Celery logger adds the task id and status and outputs via the worker. logger = get_task_logger(__name__) - -@app.task(time_limit=60) -def check_model(trackedmodel_id: int, context_id: int): - """ - Runs all of the applicable checkers on the passed model ID, and records the - results. - - Model checks are expected (from observation) to be short – on the order of - 6-7 seconds max. So if the check is taking considerably longer than this, it - is probably broken and should be killed to free up the worker. +# Types for passing over celery +CheckerModelRule = Tuple[Checker, TrackedModel, Sequence[BusinessRule]] +"""CheckerModelRule stores a checker, model, and a sequence of rules to apply to it.""" + +ModelPKInterval = Tuple[int, int] +"""ModelPKInterval is a tuple of (first_pk, last_pk) referring to a contiguous range of TrackedModels""" + +TaskInfo = Tuple[int, str] +"""TaskInfo is a tuple of (task_id, task_name) which can be used to create a ModelCeleryTask.""" + + +# def get_checker_model_rules( +# models: Sequence[TrackedModel], +# rule_names: Optional[Set[str]] = None, +# ): +# """ +# Generator of model, rules. +# +# Given a sequence of models and a sequence of checkers +# +# yield (model, [rules...]) +# """ +# +# for model in models: +# for checker in ALL_CHECKERS.values(): +# yield from ( +# (checker, checker_model, checker_rules) +# for checker_model, checker_rules in checker.get_model_rules( +# model, +# rule_names, +# ).items() +# ) + + +@app.task(trail=True) +def check_model( + transaction_pk: int, + model_pk: int, + rule_names: Optional[Sequence[str]] = None, + bind_to_task_kwargs: Optional[Dict] = None, +): """ + Task to check one model against one business rule and record the result. - model: TrackedModel = TrackedModel.objects.get(pk=trackedmodel_id) - context: TransactionCheck = TransactionCheck.objects.get(pk=context_id) - transaction = context.transaction - - with override_current_transaction(transaction): - for check in applicable_to(model): - if not context.model_checks.filter( - model=model, - check_name=check.name, - ).exists(): - # Run the checker on the model and record the result. (This is - # not Celery ``apply`` but ``Checker.apply``). - check.apply(model, context) - - -@app.task -def is_transaction_check_complete(check_id: int) -> bool: - """Checks and returns whether the given transaction check is complete, and - records the success if so.""" - - check: TransactionCheck = TransactionCheck.objects.get(pk=check_id) - check.completed = True + As this is a celery task, parameters are in base formats that can be serialised, such as int and str. - with override_current_transaction(check.transaction): - for model in check.transaction.tracked_models.all(): - applicable_checks = set(check.name for check in applicable_to(model)) - performed_checks = set( - check.model_checks.filter(model=model).values_list( - "check_name", - flat=True, - ), - ) + Run one business rule against one model, this is called as part of the check_models workflow. - if applicable_checks != performed_checks: - check.completed = False - break - - if check.completed: - check.successful = not check.model_checks.filter(successful=False).exists() - logger.info("Completed checking %s", check.transaction.summary) - - check.save() - return check.completed - - -def setup_or_resume_transaction_check(transaction: Transaction): - """Return a current, fresh transaction check for the passed transaction ID - and a list of model IDs that need to be checked.""" - - head_transaction = Transaction.approved.last() + By setting bind_to_task_uuid, the task will be bound to the celery task with the given UUID, + this is useful for tracking the progress of the parent task, and cancelling it if needed. + """ + # XXXX - TODO, re-add note on timings, from Simons original code. + + if rule_names is None: + rule_names = set(ALL_RULES.keys()) + + assert set(ALL_RULES.keys()).issuperset(rule_names) + + transaction = Transaction.objects.get(pk=transaction_pk) + model = TrackedModel.objects.get(pk=model_pk) + successful = True + + for checker in ALL_CHECKERS.values(): + for checker_model, model_rules in checker.get_model_rules( + model, + rule_names, + ).items(): + """get_model_rules will return a different model in the case of + LinkedModelChecker, so the model to check use checker_model.""" + for rule in model_rules: + logger.debug( + "%s rule: %s, tx: %s, model: %s", + checker.__name__, + rule, + transaction, + model, + ) + check_result = checker.apply_rule_cached( + rule, + transaction, + checker_model, + ) + if bind_to_task_kwargs: + logger.debug( + "Binding result %s to task. bind_to_task_kwargs: %s", + check_result.pk, + bind_to_task_kwargs, + ) + bind_model_task(check_result, **bind_to_task_kwargs) + + logger.info( + f"Ran check %s %s", + check_result, + "✅" if check_result.successful else "❌", + ) + successful &= check_result.successful + + return successful + + +@app.task(trail=True) +def check_models_workflow( + pk_intervals: Sequence[ModelPKInterval], + bind_to_task_kwargs: Optional[Dict] = None, + rules: Optional[Sequence[str]] = None, +): + """ + Celery Workflow group containing 'check_model_rule' tasks to run applicable + rules from checkers on supplied models in parallel via a celery group. - existing_checks = TransactionCheck.objects.filter( - transaction=transaction, - head_transaction=head_transaction, - ) + If checkers is None, then default to all applicable checkers + (see get_model_rules) - up_to_date_check = existing_checks.requires_update(False).filter(completed=True) - if up_to_date_check.exists(): - return up_to_date_check.get(), [] + Models checked will be the exact model versions passed in, + this is useful for caching checks, e.g. those of linked_models + where an older model is referenced. - context = existing_checks.requires_update(False).filter(completed=False).last() - if context is None: - context = TransactionCheck( - transaction=transaction, - head_transaction=head_transaction, + Callers should ensure models passed in are the correct version, + e.g. by using override_transaction. + """ + logger.debug("Build check_models_workflow") + + models = TrackedModel.objects.from_pk_intervals(*pk_intervals) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Got %s models", models.count()) + + return chord( + check_model.si( + model.transaction.pk, + model.pk, + rules, + bind_to_task_kwargs, ) - context.save() - - return ( - context, - transaction.tracked_models.values_list("pk", flat=True), + for model in models + )(unbind_model_tasks.si([bind_to_task_kwargs["celery_task_id"]])) + + +@app.task(trail=True) +def cancel_workbasket_checks(workbasket_pk: int): + """Find existing celery tasks and, revoke them ande delete the + ModelCeleryTask objects tracking them.""" + celery_tasks = ( + ModelCeleryTask.objects.filter(celery_task_name="check_workbasket") + .update_task_statuses() + .filter_by_task_kwargs(workbasket_pk=workbasket_pk) ) + # Terminate the existing tasks, using SIGUSR1 which triggers the soft timeout handler. + celery_tasks.revoke(terminate=True, signal="SIGUSR1") -@app.task(bind=True) -def check_transaction(self, transaction_id: int): - """Run and record checks for the passed transaction ID, asynchronously.""" - - transaction = Transaction.objects.get(pk=transaction_id) - check, model_ids = setup_or_resume_transaction_check(transaction) - if check.completed and not any(model_ids): - logger.debug( - "Skipping check of %s because an up-to-date check already exists", - transaction.summary, - ) - return - - # Create a workflow: firstly run all of the model checks (in parallel) and - # then once they are all done see if the transaction check is now complete. - logger.info("Beginning check of %s", transaction.summary) - workflow = group( - check_model.si(*args) for args in zip(model_ids, cycle([check.pk])) - ) | is_transaction_check_complete.si(check.pk) - - # Execute the workflow by replacing this task with it. - return self.replace(workflow) - - -def check_transaction_sync(transaction: Transaction): +@app.task(trail=True) +def get_workbasket_model_pk_intervals(workbasket_pk: int): """ - Run and record checks for the passed transaction ID, syncronously. + Return a list of all models in the workbasket. - This method will run all of the checks one after the other and won't return - until they are complete. This is useful for testing and debugging. + Ordinarily this step doesn't take very long, though for the seed workbasket + of 9 million items it may take around 6 seconds (measured on a consumer + laptop [Ryzen 2500u, 32gb ram]). """ - check, model_ids = setup_or_resume_transaction_check(transaction) - if check.completed and not any(model_ids): - logger.debug( - "Skipping check of transaction %s " - "because an up-to-date check already exists", - transaction.pk, - ) - else: - logger.info("Beginning synchronous check of %s", transaction.summary) - for model_id in model_ids: - check_model(model_id, check.pk) - is_transaction_check_complete(check.pk) - - -@app.task(bind=True, rate_limit="1/m") -def update_checks(self): + workbasket = WorkBasket.objects.get(pk=workbasket_pk) + pks = [*workbasket.tracked_models.as_pk_intervals()] + return pks + + +@app.task(trail=True) +def unbind_model_tasks(task_ids: Sequence[str]): + """Called at the end of a workflow, as there is no ongoing celery task + associated with this data.""" + logger.debug("Task_ids: [%s]", task_ids) + deleted = ModelCeleryTask.objects.filter(celery_task_id__in=task_ids).delete() + logger.debug("Deleted %s ModelCeleryTask objects", deleted[0]) + + +@app.task(bind=True, trail=True) +def check_workbasket( + self, + workbasket_pk: int, + current_transaction_pk: Optional[int] = None, + rules: Optional[Sequence[str]] = None, + clear_cache=False, +): """ - Triggers checking for any transaction that requires an update. - - A rate limit is specified here to mitigate instances where this - task stacks up and prevents other tasks from running by monopolising - the worker. - - TODO: Ensure this task is *not* stacking up and blocking the worker! + Orchestration task, that kicks off a workflow to check all models in the + workbasket. + + Cancels existing tasks if they are running, the system has caching which + will help with overlapping checks, cancelling existing checks will help keep + the celery queue clear of stale tasks, which is makes it easier to manage + when the system is under load. + + :param workbasket_pk: pk of the workbasket to check + :param current_transaction_pk: pk of the current transaction, defaults to the current highest transaction + :param rules: specify rule names to check (defaults to ALL_RULES) [mostly for testing/debugging] + :param clear_cache: clear the cache before checking [mostly for testing/debugging] """ - - ids_require_update = ( - Transaction.objects.exclude( - pk__in=TransactionCheck.objects.requires_update(False).values( - "transaction__pk", - ), - ) - .filter(partition=TransactionPartition.DRAFT) - .values_list("pk", flat=True) + logger.debug( + "check_workbasket, workbasket_pk: %s, current_transaction_pk %s, clear_cache %s", + workbasket_pk, + current_transaction_pk, + clear_cache, ) - # Execute a check for each transaction that requires an update by replacing - # this task with a parallel workflow. - return self.replace(group(check_transaction.si(id) for id in ids_require_update)) + if clear_cache: + # Clearing the cache should not be needed in the usual workflow, but may be useful e.g. if + # business rules are updated and need to be re-run. + TrackedModelCheck.objects.filter( + model__transaction__workbasket__pk=workbasket_pk, + ).delete() + + if current_transaction_pk is None: + current_transaction_pk = ( + get_current_transaction() or Transaction.objects.last() + ).pk + + # Use 'bind_to_task' to pass in the celery task id to associate this task and it's subtasks, while + # the task is running, allowing them to be revoked if the underlying data changes or another copy + # of the task is started. + # + # get_workbasket_model_pk_intervals gets tuples of (first_pk, last_pk), a compact form to + # represent the trackedmodels in the workbasket, which is passed to the subtasks tasks. + return chain( + cancel_workbasket_checks.si(workbasket_pk), + get_workbasket_model_pk_intervals.si(workbasket_pk), + check_models_workflow.s( + bind_to_task_kwargs={ + "celery_task_id": self.request.id, + "celery_task_name": "check_workbasket", + }, + rules=rules, + ), + )() + + +def check_workbasket_sync(workbasket: WorkBasket, clear_cache: bool = False): + # Run the celery task and wait + tx = get_current_transaction() + result = check_workbasket.delay(workbasket.pk, tx.pk, clear_cache) + result.wait() diff --git a/checks/tests/factories.py b/checks/tests/factories.py index 46c247d949..584413c590 100644 --- a/checks/tests/factories.py +++ b/checks/tests/factories.py @@ -1,12 +1,8 @@ from dataclasses import dataclass from typing import Optional -import factory - -from checks import models from checks.checks import Checker from common.models.trackedmodel import TrackedModel -from common.tests import factories @dataclass(frozen=True) @@ -19,71 +15,71 @@ def run(self, model: TrackedModel): return self.success, self.message -class TransactionCheckFactory(factory.django.DjangoModelFactory): - class Meta: - model = models.TransactionCheck - - transaction = factory.SubFactory( - factories.TransactionFactory, - draft=True, - ) - completed = True - successful = True - head_transaction = factory.SubFactory(factories.ApprovedTransactionFactory) - tracked_model_count = factory.LazyAttribute( - lambda check: (len(check.transaction.tracked_models.all())), - ) - latest_tracked_model = factory.SubFactory( - factories.TestModel1Factory, - transaction=factory.SelfAttribute("..transaction"), - ) - - class Params: - incomplete = factory.Trait( - completed=False, - successful=None, - ) - - empty = factory.Trait( - latest_tracked_model=None, - tracked_model_count=0, - ) - - -class StaleTransactionCheckFactory(TransactionCheckFactory): - class Meta: - exclude = ("first", "second") - - first = factory.SubFactory( - factories.TestModel1Factory, - transaction=factory.SelfAttribute("..transaction"), - ) - second = factory.SubFactory( - factories.TestModel1Factory, - transaction=factory.SelfAttribute("..transaction"), - ) - - latest_tracked_model = factory.SelfAttribute("second") - - @classmethod - def _after_postgeneration(cls, instance: TrackedModel, create, results=None): - """Save again the instance if creating and at least one hook ran.""" - super()._after_postgeneration(instance, create, results) - - if create: - assert instance.transaction.tracked_models.count() >= 2 - instance.transaction.tracked_models.first().delete() - - -class TrackedModelCheckFactory(factory.django.DjangoModelFactory): - class Meta: - model = models.TrackedModelCheck - - model = factory.SubFactory( - factories.TestModel1Factory, - transaction=factory.SelfAttribute("..transaction_check.transaction"), - ) - transaction_check = factory.SubFactory(TransactionCheckFactory) - check_name = factories.string_sequence() - successful = True - message = None +# class TransactionCheckFactory(factory.django.DjangoModelFactory): +# class Meta: +# model = models.TransactionCheck +# +# transaction = factory.SubFactory( +# factories.TransactionFactory, +# draft=True, +# ) +# completed = True +# successful = True +# head_transaction = factory.SubFactory(factories.ApprovedTransactionFactory) +# tracked_model_count = factory.LazyAttribute( +# lambda check: (len(check.transaction.tracked_models.all())), +# ) +# latest_tracked_model = factory.SubFactory( +# factories.TestModel1Factory, +# transaction=factory.SelfAttribute("..transaction"), +# ) +# +# class Params: +# incomplete = factory.Trait( +# completed=False, +# successful=None, +# ) +# +# empty = factory.Trait( +# latest_tracked_model=None, +# tracked_model_count=0, +# ) +# +# +# class StaleTransactionCheckFactory(TransactionCheckFactory): +# class Meta: +# exclude = ("first", "second") +# +# first = factory.SubFactory( +# factories.TestModel1Factory, +# transaction=factory.SelfAttribute("..transaction"), +# ) +# second = factory.SubFactory( +# factories.TestModel1Factory, +# transaction=factory.SelfAttribute("..transaction"), +# ) +# +# latest_tracked_model = factory.SelfAttribute("second") +# +# @classmethod +# def _after_postgeneration(cls, instance: TrackedModel, create, results=None): +# """Save again the instance if creating and at least one hook ran.""" +# super()._after_postgeneration(instance, create, results) +# +# if create: +# assert instance.transaction.tracked_models.count() >= 2 +# instance.transaction.tracked_models.first().delete() +# +# +# class TrackedModelCheckFactory(factory.django.DjangoModelFactory): +# class Meta: +# model = models.TrackedModelCheck +# +# model = factory.SubFactory( +# factories.TestModel1Factory, +# transaction=factory.SelfAttribute("..transaction_check.transaction"), +# ) +# transaction_check = factory.SubFactory(TransactionCheckFactory) +# check_name = factories.string_sequence() +# successful = True +# message = None diff --git a/checks/tests/test_checkers.py b/checks/tests/test_checkers.py index b61738c027..426c5074b9 100644 --- a/checks/tests/test_checkers.py +++ b/checks/tests/test_checkers.py @@ -1,78 +1,148 @@ -from itertools import chain - import pytest -import checks.tests.factories from checks.checks import BusinessRuleChecker -from checks.checks import IndirectBusinessRuleChecker -from checks.checks import checker_types +from checks.checks import LinkedModelsBusinessRuleChecker + +# from checks.checks import checker_types # TODO +from checks.models import TrackedModelCheck from common.tests import factories -from common.tests.util import TestRule +from common.tests.util import TestRule1 +from common.tests.util import TestRule2 from common.tests.util import add_business_rules pytestmark = pytest.mark.django_db -def test_all_business_rules_have_a_checker(trackedmodel_factory): - """Verify that each BusinessRule has a corresponding Checker.""" - checkers = set(checker.rule for checker in checker_types()) - model_type = trackedmodel_factory._meta.model - model_rules = set( - chain(model_type.business_rules, model_type.indirect_business_rules), - ) - assert checkers.intersection(model_rules) == model_rules - - -def test_business_rules_validation(): - """Verify that ``Checker.apply`` calls ``validate`` on it's matching - BusinessRule.""" +@pytest.mark.parametrize( + "applicable_rules, rule_filter, expected_rules", + [ + ({TestRule1}, None, {TestRule1}), + ({TestRule1, TestRule2}, None, {TestRule1, TestRule2}), + ({TestRule1, TestRule2}, [TestRule1], {TestRule1}), + ({TestRule1}, [TestRule2], set()), + (set(), None, set()), + ], +) +def test_business_rules_validation(applicable_rules, rule_filter, expected_rules): + """Verify that ``BusinessRuleChecker.apply_rule`` calls ``validate`` on it's + matching BusinessRule.""" model = factories.TestModel1Factory.create() - check = checks.tests.factories.TransactionCheckFactory( - transaction=model.transaction, - ) - with add_business_rules(type(model), TestRule): - checker_type = BusinessRuleChecker.of(TestRule) + with add_business_rules(type(model), *applicable_rules): + model_rules = BusinessRuleChecker.get_model_rules(model) + assert isinstance(model_rules, dict) - # Verify the cache returns the same object if .of is called a second time. - assert checker_type is BusinessRuleChecker.of(TestRule) + if not expected_rules: + assert model_rules == {} + return - checkers = checker_type.checkers_for(model) + assert model_rules == {model: expected_rules} + check = BusinessRuleChecker.apply_rule(TestRule1, model.transaction, model) - for checker in checkers: - checker.apply(model, check) - assert TestRule.validate.called_with(model) + assert TestRule1.validate.called_with(model) + assert isinstance(check, TrackedModelCheck) -def test_indirect_business_rule_validation(): - model = factories.TestModel1Factory.create() - descs = set( - factories.TestModelDescription1Factory.create_batch( - size=4, - described_record=model, +@pytest.mark.parametrize( + "checker, expected_error_message_template", + [ + ( + BusinessRuleChecker, + "{model} does not have {rule} in its business_rules attribute.", ), - ) - check = checks.tests.factories.TransactionCheckFactory( - transaction=model.transaction, + ( + LinkedModelsBusinessRuleChecker, + "{model} does not have {rule} in its indirect_business_rules attribute.", + ), + ], +) +def test_business_rules_validation_raises_exception_for_unknown_rule( + checker, + expected_error_message_template, +): + """Verify that calling apply_rule with a different rule to that specified in + a TrackedModels business_rules attribute raises a ValueError.""" + model = factories.TestModel1Factory.create() + expected_error_message = expected_error_message_template.format( + model=model, + rule=TestRule2, ) - with add_business_rules(type(model), TestRule), add_business_rules( + with add_business_rules(type(model), TestRule1), add_business_rules( factories.TestModelDescription1Factory._meta.model, - TestRule, + TestRule1, indirect=True, ): - checker_type = IndirectBusinessRuleChecker.of(TestRule) + model_rules = checker.get_model_rules(model) + assert isinstance(model_rules, dict) + + with pytest.raises(ValueError, match=expected_error_message): + # Calling checker.apply_rule with a rule that doesn't apply should raise an error. + checker.apply_rule(TestRule2, model.transaction, model) + + +# def test_indirect_business_rule_validation(): +# model = factories.TestModel1Factory.create() +# descs = set( +# factories.TestModelDescription1Factory.create_batch( +# size=4, +# described_record=model, +# ), +# ) +# check = checks.tests.factories.TransactionCheckFactory( +# transaction=model.transaction, +# ) +# +# with add_business_rules(type(model), TestRule), add_business_rules( +# factories.TestModelDescription1Factory._meta.model, +# TestRule, +# indirect=True, +# ): +# checker_type = LinkedModelsBusinessRuleChecker.of(TestRule) +# checkers = checker_type.checkers_for(model) +# +# # Assert every checker has a unique name. +# assert len(checkers) == len(set(c.name for c in checkers)) +# +# for checker in checkers: +# checker.apply(model, check) +# +# for desc in descs: +# assert TestRule.validate.called_with(desc) + + +@pytest.mark.parametrize( + "applicable_rules, expected_rules", + [ + ({TestRule1.__name__}, {TestRule1}), + # ({TestRule.__name__, 'other'}, {TestRule}), + # ({'other'}, set()), + # (set(), set()), + ], +) +def test_indirect_business_rules_validation(applicable_rules, expected_rules): + """Verify that ``LinkedModelsBusinessRuleChecker.apply_rule`` calls + ``validate`` on it's matching BusinessRule.""" + model = factories.TestModel1Factory.create() + desc1, desc2 = factories.TestModelDescription1Factory.create_batch( + size=2, + described_record=model, + ) - # Verify the cache returns the same object if .of is called a second time. - assert checker_type is IndirectBusinessRuleChecker.of(TestRule) + with add_business_rules(type(model), TestRule1), add_business_rules( + factories.TestModelDescription1Factory._meta.model, + TestRule1, + indirect=True, + ): - checkers = checker_type.checkers_for(model) + desc1_model_rules = LinkedModelsBusinessRuleChecker.get_model_rules(desc1) + desc2_model_rules = LinkedModelsBusinessRuleChecker.get_model_rules(desc2) - # Assert every checker has a unique name. - assert len(checkers) == len(set(c.name for c in checkers)) + assert desc1_model_rules == {model: {TestRule1}} + assert desc1_model_rules == desc2_model_rules - for checker in checkers: - checker.apply(model, check) + # check = LinkedModelsBusinessRuleChecker.apply_rule(TestRule1, desc.transaction, desc) + LinkedModelsBusinessRuleChecker.apply_rule(TestRule1, model.transaction, model) + # LinkedModelsBusinessRuleChecker.apply_rule(TestRule1, model.transaction, model) - for desc in descs: - assert TestRule.validate.called_with(desc) + TestRule1.validate.assert_called_once_with(model) diff --git a/checks/tests/test_tasks.py b/checks/tests/test_tasks.py index c359332472..12aaf5b293 100644 --- a/checks/tests/test_tasks.py +++ b/checks/tests/test_tasks.py @@ -6,12 +6,7 @@ from pytest_django.asserts import assertQuerysetEqual # type: ignore from checks import tasks -from checks.models import TransactionCheck from checks.tests import factories -from checks.tests.util import assert_requires_update -from common.models.transactions import TransactionPartition -from common.tests import factories as common_factories -from workbaskets.validators import WorkflowStatus pytestmark = pytest.mark.django_db @@ -91,130 +86,130 @@ def test_model_checking(check): assert check.model_checks.filter(successful=True).count() == num_successful -def test_completion_of_transaction_checks(check): - check, num_checks, num_completed, num_successful = check - expect_completed = num_completed == num_checks - expect_successful = (num_successful == num_checks) if expect_completed else None - - complete = tasks.is_transaction_check_complete(check.id) - assert complete == expect_completed - - check.refresh_from_db() - assert check.completed == expect_completed - assert check.successful == expect_successful - - -@pytest.mark.parametrize("check_already_exists", (True, False)) -def test_checking_of_transaction(check, check_already_exists): - check, num_checks, num_completed, num_successful = check - expect_completed = num_completed == num_checks - expect_successful = (num_successful == num_checks) if expect_completed else None - if expect_completed: - check.completed = True - check.successful = expect_successful - check.save() - - transaction = check.transaction - if not check_already_exists: - check.delete() - - # The task will replace itself with a new workflow. Testing this is hard. - # Instead, we will capture the new workflow and assert it is calling the - # right things. This is brittle but probably better than nothing. - with mock.patch("celery.app.task.Task.replace", new=lambda _, t: t): - workflow = tasks.check_transaction(transaction.id) # type: ignore - - check = TransactionCheck.objects.filter(transaction=transaction).get() - if expect_completed and check_already_exists: - # If the check is already done, it should be skipped. - assert workflow is None - else: - # If checks need to happen, the workflow should have one check task per - # model and finish with a decide task. - assert transaction.tracked_models.count() == len(workflow.tasks) - model_ids = set(transaction.tracked_models.values_list("id", flat=True)) - for task in workflow.tasks: - model_id, context_id = task.args - model_ids.remove(model_id) - assert task.task == tasks.check_model.name - assert context_id == check.id - - assert workflow.body.task == tasks.is_transaction_check_complete.name - assert workflow.body.args[0] == check.id - - -def test_detecting_of_transactions_to_update(): - head_transaction = common_factories.ApprovedTransactionFactory.create() - - # Transaction with no check - no_check = common_factories.UnapprovedTransactionFactory.create() - - # Transaction that does not require update - no_update = factories.TransactionCheckFactory.create( - head_transaction=head_transaction, - ) - assert_requires_update(no_update, False) - - # Transaction that requires update in DRAFT - draft_update = factories.StaleTransactionCheckFactory.create( - transaction__partition=TransactionPartition.DRAFT, - head_transaction=head_transaction, - ) - assert_requires_update(draft_update, True) - - # Transaction that requires update in REVISION - revision_update = factories.StaleTransactionCheckFactory.create( - transaction__partition=TransactionPartition.REVISION, - transaction__order=-(head_transaction.order), - head_transaction=head_transaction, - ) - assert_requires_update(revision_update, True) - - expected_transaction_ids = {no_check.id, draft_update.transaction.id} - - # The task will replace itself with a new workflow. Testing this is hard. - # Instead, we will capture the new workflow and assert it is calling the - # right things. This is brittle but probably better than nothing. - with mock.patch("celery.app.task.Task.replace", new=lambda _, t: t): - workflow = tasks.update_checks() # type: ignore - - assert set(t.task for t in workflow.tasks) == {tasks.check_transaction.name} - assert set(t.args[0] for t in workflow.tasks) == expected_transaction_ids - - -@pytest.mark.parametrize("include_archived", [True, False]) -@pytest.mark.parametrize( - "transaction_partition", [TransactionPartition.DRAFT, TransactionPartition.REVISION] -) -def test_archived_workbasket_checks(include_archived, transaction_partition): - """ - Verify transactions in ARCHIVED workbaskets do not require checking unless - include_archived is True. - """ - head_transaction = common_factories.ApprovedTransactionFactory.create() - - # Transaction that requires update in DRAFT or REVISION - transaction_check = factories.StaleTransactionCheckFactory.create( - transaction__partition=transaction_partition, - head_transaction=head_transaction, - ) - - all_checks = TransactionCheck.objects.filter(pk=transaction_check.pk) - initial_require_update = all_checks.requires_update(True, include_archived) - - # Initially the transaction should require update. - assert initial_require_update.count() == 1 - assert initial_require_update.get().pk == transaction_check.pk - - # Set workbasket status to ARCHIVED and verify requires_update only returns their transaction checks if - # include_archived is True - transaction_check.transaction.workbasket.status = WorkflowStatus.ARCHIVED - transaction_check.transaction.workbasket.save() - - checks_require_update = all_checks.requires_update(True, include_archived) - - if include_archived: - assert checks_require_update.count() == 1 - assert checks_require_update.get().pk == transaction_check.pk - else: - assert checks_require_update.count() == 0 +# def test_completion_of_transaction_checks(check): +# check, num_checks, num_completed, num_successful = check +# expect_completed = num_completed == num_checks +# expect_successful = (num_successful == num_checks) if expect_completed else None +# +# complete = tasks.is_transaction_check_complete(check.id) +# assert complete == expect_completed +# +# check.refresh_from_db() +# assert check.completed == expect_completed +# assert check.successful == expect_successful +# +# +# @pytest.mark.parametrize("check_already_exists", (True, False)) +# def test_checking_of_transaction(check, check_already_exists): +# check, num_checks, num_completed, num_successful = check +# expect_completed = num_completed == num_checks +# expect_successful = (num_successful == num_checks) if expect_completed else None +# if expect_completed: +# check.completed = True +# check.successful = expect_successful +# check.save() +# +# transaction = check.transaction +# if not check_already_exists: +# check.delete() +# +# # The task will replace itself with a new workflow. Testing this is hard. +# # Instead, we will capture the new workflow and assert it is calling the +# # right things. This is brittle but probably better than nothing. +# with mock.patch("celery.app.task.Task.replace", new=lambda _, t: t): +# workflow = tasks.check_transaction(transaction.id) # type: ignore +# +# check = TransactionCheck.objects.filter(transaction=transaction).get() +# if expect_completed and check_already_exists: +# # If the check is already done, it should be skipped. +# assert workflow is None +# else: +# # If checks need to happen, the workflow should have one check task per +# # model and finish with a decide task. +# assert transaction.tracked_models.count() == len(workflow.tasks) +# model_ids = set(transaction.tracked_models.values_list("id", flat=True)) +# for task in workflow.tasks: +# model_id, context_id = task.args +# model_ids.remove(model_id) +# assert task.task == tasks.check_model.name +# assert context_id == check.id +# +# assert workflow.body.task == tasks.is_transaction_check_complete.name +# assert workflow.body.args[0] == check.id +# +# +# def test_detecting_of_transactions_to_update(): +# head_transaction = common_factories.ApprovedTransactionFactory.create() +# +# # Transaction with no check +# no_check = common_factories.UnapprovedTransactionFactory.create() +# +# # Transaction that does not require update +# no_update = factories.TransactionCheckFactory.create( +# head_transaction=head_transaction, +# ) +# assert_requires_update(no_update, False) +# +# # Transaction that requires update in DRAFT +# draft_update = factories.StaleTransactionCheckFactory.create( +# transaction__partition=TransactionPartition.DRAFT, +# head_transaction=head_transaction, +# ) +# assert_requires_update(draft_update, True) +# +# # Transaction that requires update in REVISION +# revision_update = factories.StaleTransactionCheckFactory.create( +# transaction__partition=TransactionPartition.REVISION, +# transaction__order=-(head_transaction.order), +# head_transaction=head_transaction, +# ) +# assert_requires_update(revision_update, True) +# +# expected_transaction_ids = {no_check.id, draft_update.transaction.id} +# +# # The task will replace itself with a new workflow. Testing this is hard. +# # Instead, we will capture the new workflow and assert it is calling the +# # right things. This is brittle but probably better than nothing. +# with mock.patch("celery.app.task.Task.replace", new=lambda _, t: t): +# workflow = tasks.update_checks() # type: ignore +# +# assert set(t.task for t in workflow.tasks) == {tasks.check_transaction.name} +# assert set(t.args[0] for t in workflow.tasks) == expected_transaction_ids +# +# +# @pytest.mark.parametrize("include_archived", [True, False]) +# @pytest.mark.parametrize( +# "transaction_partition", [TransactionPartition.DRAFT, TransactionPartition.REVISION] +# ) +# def test_archived_workbasket_checks(include_archived, transaction_partition): +# """ +# Verify transactions in ARCHIVED workbaskets do not require checking unless +# include_archived is True. +# """ +# head_transaction = common_factories.ApprovedTransactionFactory.create() +# +# # Transaction that requires update in DRAFT or REVISION +# transaction_check = factories.StaleTransactionCheckFactory.create( +# transaction__partition=transaction_partition, +# head_transaction=head_transaction, +# ) +# +# all_checks = TransactionCheck.objects.filter(pk=transaction_check.pk) +# initial_require_update = all_checks.requires_update(True, include_archived) +# +# # Initially the transaction should require update. +# assert initial_require_update.count() == 1 +# assert initial_require_update.get().pk == transaction_check.pk +# +# # Set workbasket status to ARCHIVED and verify requires_update only returns their transaction checks if +# # include_archived is True +# transaction_check.transaction.workbasket.status = WorkflowStatus.ARCHIVED +# transaction_check.transaction.workbasket.save() +# +# checks_require_update = all_checks.requires_update(True, include_archived) +# +# if include_archived: +# assert checks_require_update.count() == 1 +# assert checks_require_update.get().pk == transaction_check.pk +# else: +# assert checks_require_update.count() == 0 diff --git a/checks/tests/util.py b/checks/tests/util.py index 4da60b856b..aa2a7a1301 100644 --- a/checks/tests/util.py +++ b/checks/tests/util.py @@ -1,6 +1,6 @@ from pytest_django.asserts import assertQuerysetEqual # type: ignore -from checks.models import TransactionCheck +# from checks.models import TransactionCheck # TODO def assert_queryset(queryset, expected): diff --git a/commodities/tests/test_business_rules.py b/commodities/tests/test_business_rules.py index 19be75982b..fede86d18e 100644 --- a/commodities/tests/test_business_rules.py +++ b/commodities/tests/test_business_rules.py @@ -1,7 +1,7 @@ import pytest from django.db import DataError -from checks.tasks import check_transaction_sync +# from checks.tasks import check_transaction_sync # TODO from commodities import business_rules from common.business_rules import BusinessRuleViolation from common.tests import factories diff --git a/common/migrations/0006_modelcelerytask_taskmodel.py b/common/migrations/0006_modelcelerytask_taskmodel.py new file mode 100644 index 0000000000..395b31a2c8 --- /dev/null +++ b/common/migrations/0006_modelcelerytask_taskmodel.py @@ -0,0 +1,92 @@ +# Generated by Django 3.1.14 on 2022-08-02 20:32 + +import django.db.models.deletion +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + + dependencies = [ + ("contenttypes", "0002_remove_content_type_name"), + ("common", "0005_transaction_index"), + ] + + operations = [ + migrations.CreateModel( + name="TaskModel", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "polymorphic_ctype", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="polymorphic_common.taskmodel_set+", + to="contenttypes.contenttype", + ), + ), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + ), + migrations.CreateModel( + name="ModelCeleryTask", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "celery_task_name", + models.CharField( + blank=True, + db_index=True, + max_length=64, + null=True, + ), + ), + ("celery_task_id", models.CharField(db_index=True, max_length=64)), + ("last_task_status", models.CharField(max_length=8)), + ( + "object", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="common.taskmodel", + ), + ), + ( + "polymorphic_ctype", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="polymorphic_common.modelcelerytask_set+", + to="contenttypes.contenttype", + ), + ), + ], + options={ + "unique_together": {("celery_task_id", "object")}, + }, + ), + ] diff --git a/common/models/__init__.py b/common/models/__init__.py index a55b86f6fc..f5b7f455ec 100644 --- a/common/models/__init__.py +++ b/common/models/__init__.py @@ -3,6 +3,8 @@ from common.fields import NumericSID from common.fields import ShortDescription from common.fields import SignedIntSID +from common.models.celerytask import ModelCeleryTask +from common.models.celerytask import TaskModel from common.models.mixins import TimestampedMixin from common.models.mixins.description import DescriptionMixin from common.models.mixins.validity import ValidityMixin @@ -13,6 +15,8 @@ __all__ = [ "ApplicabilityCode", + "TaskModel", + "ModelCeleryTask", "NumericSID", "ShortDescription", "SignedIntSID", diff --git a/common/models/celerytask.py b/common/models/celerytask.py new file mode 100644 index 0000000000..103156f81d --- /dev/null +++ b/common/models/celerytask.py @@ -0,0 +1,194 @@ +""" +Provide a way to link a Celery Task (usually referencable from a UUID) to a +django Model. + +This enables retrieving the realtime status of tasks while they are running. + +Once tasks have completed, these models should be deleted. +""" + +from celery.result import AsyncResult +from celery.utils.log import get_task_logger +from django.db import models +from polymorphic.models import PolymorphicModel +from polymorphic.query import PolymorphicQuerySet + +from common.celery import app as celery_app + +logger = get_task_logger(__name__) + + +class TaskModel(PolymorphicModel): + """ + Mixin for models that can be linked to a celery task. + + All celery specific functionality is at the other end of the relationship, + on ModelCeleryTask, leaving an extension point for other non-celery based + implementations. + """ + + +class ModelCeleryTaskQuerySet(PolymorphicQuerySet): + def filter_by_task_status(self, statuses=None): + """ + Note: Passing in task ids that are not known to Celery will return Tasks with 'PENDING' status, + as celery can't know if these are tasks that have not reached the broker yet or just don't exist. + """ + model_task_ids = ( + model_task.pk + for model_task in self + if statuses is None or model_task.get_celery_task_status() in statuses + ) + + return self.filter(pk__in=model_task_ids) + + def filter_by_task_kwargs(self, **kwargs): + def task_kwargs_match(task): + """ + :return: True if all the specified kwargs match those on the task. + """ + if task.kwargs is None: + return False + + for k, v in kwargs.items(): + if k not in task.kwargs or task.kwargs[k] != v: + return False + return True + + model_task_ids = ( + model_task.pk + for model_task in self + if task_kwargs_match(model_task.get_celery_task()) + ) + + return self.filter(pk__in=model_task_ids) + + def filter_by_task_args(self, *args): + model_task_ids = ( + model_task.pk + for model_task in self + if model_task.get_celery_task().result.args == args + ) + + return self.filter(pk__in=model_task_ids) + + def update_task_statuses(self): + """Update the last_task_status of all modeltasks in the queryset from + celery.""" + model_tasks = self # .all() + for model_task in model_tasks: + task_status = model_task.get_celery_task_status() + # 'PENDING' can mean the task is not yet known to celery, + # or it is a task that has not yet reached the broker, if + # the status goes *back* to 'PENDING' from a higher status + # then don't forget it the higher status. + if not model_task.last_task_status or task_status != "PENDING": + model_task.last_task_status = task_status + + self.model.objects.bulk_update( + model_tasks, + ["last_task_status"], + batch_size=2000, + ) + return model_tasks + + def delete_pending_tasks(self): + """""" + return self.filter_by_task_status("PENDING").delete() + + def revoke(self, **kwargs): + for task_id in self.values_list("celery_task_id", flat=True): + task = AsyncResult(task_id) + task.revoke(**kwargs) + + self.delete() + + +class ModelCeleryTask(PolymorphicModel): + """ + Provide a way to link a Celery Task (usually referencable from a UUID) to a + django Model. + + ModelCeleryTask instances should be created at the same time as the Celery Task they are + linked to. + + This is because 'PENDING' in Celery either means the task is queued or is returned for unknown + tasks. + """ + + class Meta: + unique_together = ("celery_task_id", "object") + + objects = ModelCeleryTaskQuerySet.as_manager() + + celery_task_name = models.CharField( + max_length=64, + null=True, + blank=True, + db_index=True, + ) + celery_task_id = models.CharField(max_length=64, db_index=True) + last_task_status = models.CharField(max_length=8) + + object = models.ForeignKey( + "common.TaskModel", + blank=True, + null=True, + default=None, + on_delete=models.CASCADE, + ) + + def get_celery_task(self): + """Get a reference to the Celery task instance.""" + return celery_app.AsyncResult(self.celery_task_id) + + def get_celery_task_status(self): + """Query celery and return the task status.""" + return self.get_celery_task().status + + @classmethod + def bind_model(cls, object: TaskModel, celery_task_id: str, celery_task_name: str): + """Link a Celery Task UUID to a django Model.""" + model_task, created = ModelCeleryTask.objects.get_or_create( + {"celery_task_name": celery_task_name}, + object=object, + celery_task_id=celery_task_id, + ) + if not created: + # Call save to update the last_task_status from celery. + # (on creation, save will have been called by django) + model_task.save() + + logger.debug("Bound celery task %s to %s", celery_task_id, object) + return model_task + + @classmethod + def unbind_model(cls, object: TaskModel): + """Unlink a Celery Task UUID from a django Model.""" + return ModelCeleryTask.objects.filter( + object=object, + ).delete() + + def save(self, *args, **kwargs): + """Override save to update the last_task_status from celery.""" + task_status = self.get_celery_task_status() + if not self.last_task_status or task_status != "PENDING": + # 'PENDING' can mean the task is not yet known to celery, + # or it is a task that has not yet reached the broker, if + # the status goes *back* to 'PENDING' from a higher status + # then don't forget it the higher status. + self.last_task_status = task_status + super().save(*args, **kwargs) + + def __repr__(self): + return f"" + + +def bind_model_task(object: TaskModel, celery_task_id: str, celery_task_name: str): + """Link a Celery Task UUID to a PolymorphicModel instance.""" + return ModelCeleryTask.bind_model(object, celery_task_id, celery_task_name) + + +def unbind_model_task(object: TaskModel): + """Link a Celery Task UUID to a PolymorphicModel model instance.""" + return ModelCeleryTask.unbind_model(object) diff --git a/common/models/tracked_qs.py b/common/models/tracked_qs.py index 2683a11326..5b7cc337c3 100644 --- a/common/models/tracked_qs.py +++ b/common/models/tracked_qs.py @@ -1,8 +1,10 @@ from __future__ import annotations from hashlib import sha256 +from itertools import chain from typing import List +from django.contrib.contenttypes.models import ContentType from django.db.models import Case from django.db.models import CharField from django.db.models import F @@ -24,6 +26,9 @@ from common.util import resolve_path from common.validators import UpdateType +PK_INTERVAL_CHUNK_SIZE = 1024 * 64 +"""Default chunk size for primary key intervals, see: `as_pk_intervals.`""" + class TrackedModelQuerySet( PolymorphicQuerySet, @@ -352,11 +357,142 @@ def follow_path(self, path: str) -> TrackedModelQuerySet: return qs.distinct() + def as_pk_intervals(self, chunk_size=PK_INTERVAL_CHUNK_SIZE): + """ + Given a sequence of primary keys, return interval tuples + of: ((first_pk, last_pk), ...) + + By default[1] this provides a much smaller wire format, than, for instance sending all primary keys, thus being sutable for use in Celery. + In the happy case, a single interval will be returned, but in the case of a large number of primary keys, multiple intervals will be returned - determined by chunk_size, gaps in the original data set will also generate more intervals. Gaps may be generated when users delete items workbaskets. + + + Chunking is provided to make it easy to chunk up data for consumers of this data (e.g. a celery task on the other end). + + Unscientifically testing this on a developers' laptop with the seed workbasket (pk=1) with > 9m models, this takes 9.2 seconds, to generate 3426 interval pairs, with 128kb chunks this generates 3430 pairs. + + [1] Under a pathological case, where every primary key iterated by more than one, this would be worse. + """ + qs = self + if qs.query.order_by != ("pk",): + qs = self.order_by("pk") + + pks = qs.values_list("pk", flat=True) + + model_iterator = iter(pks) + + try: + pk = next(model_iterator) + except StopIteration: + return + + first_pk = pk + item_count = 0 + try: + while True: + item_count += 1 + last_pk = pk + pk = next(model_iterator) + if (item_count > chunk_size) or (pk > last_pk + 1): + # Yield an interval tuple of (first_pk, last_pk) if the amount of items is more than the chunk size, + # or if the pks are not contiguous. + yield first_pk, last_pk + first_pk = pk + item_count = 0 + except StopIteration: + pass + + yield first_pk, pk + + def from_pk_intervals(self, *pk_intervals): + """ + Returns a queryset of TrackedModel objects that match the primary key + tuples, (start, end) + + To generate data in this format call as_pk_ranges on a queryset. + """ + q = Q() + for first_pk, last_pk in pk_intervals: + q |= Q(pk__gte=first_pk, pk__lte=last_pk) + + if not q: + # An empty filter would match everything, so return an empty queryset in that case. + return self.none() + + return self.filter(q) + + @classmethod + def _content_hash(cls, models): + """ + Implementation of content hashing, shared by content_hash and + content_hash_fast, should not be called directly, instead `content_hash` + or `content_hash_fast` should be called which impose order on the + models. + + Code is shared in this private method so the naive and fast + implementations return the same hash. + """ + sha = sha256() + for o in models: + sha.update(o.content_hash().digest()) + return sha + def content_hash(self): """ :return: Combined sha256 hash for all contained TrackedModels. + + Ordering is by TrackedModel primary key, so the hash will be stable across multiple queries. """ - sha = sha256() - for o in self: - sha.update(o.content_hash()) - return sha.digest() + return self._content_hash(self.order_by("pk").iterator()) + + def group_by_type(self): + """Yield a sequence of query sets, where each queryset contains only one + Polymorphic ctype, enabling the use of prefetch and select_related on + them.""" + pks = self.values_list("pk", flat=True) + polymorphic_ctypes = ( + self.non_polymorphic() + .distinct("polymorphic_ctype_id") + .values_list("polymorphic_ctype", flat=True) + ) + + for polymorphic_ctype in polymorphic_ctypes: + # Query contenttypes to get the concrete class instance + klass = ContentType.objects.get_for_id(polymorphic_ctype).model_class() + yield klass.objects.filter(pk__in=pks) + + def select_related_copyable_fields(self): + """Split models into separate querysets, using group_by_type and call + select_related on any related fields found in the `copyable_fields` + attribute.""" + pks = self.values_list("pk", flat=True) + for qs in self.group_by_type(): + # Work out which fields from copyable_fields may be use in select_related + related_fields = [ + field.name + for field in qs.model.copyable_fields + if hasattr(field, "related_query_name") + ] + yield qs.select_related(*related_fields).filter(pk__in=pks) + + def content_hash_fast(self): + """ + Use `select_related_copyable_fields` to call select_related on fields + that will be hashed. + + This increases the speed a little more than 2x, at the expense of keeping the data in memory. + On this developers' laptop 2.3 seconds vs 6.5 for the naive implementation in `content_hash`, + for larger amounts of data the difference got bigger, 23 seconds vs 90, though this may + because more types of data were represented. + + For larger workbaskets batching should be used to keep memory usage withing reasonable bounds. + + The hash value returned here should be the same as that from `content_hash`. + """ + # Fetch data using select_related, at this point the ordering + # will have been lost. + all_models = chain(*self.select_related_copyable_fields()) + + # Sort the data using trackedmodel_ptr_id, since the previous step outputs + # an iterable, sorted is used, instead of order_by on a queryset. + sorted_models = sorted(all_models, key=lambda o: o.trackedmodel_ptr_id) + return self._content_hash(sorted_models) diff --git a/common/models/tracked_utils.py b/common/models/tracked_utils.py index 34dc1de39d..7e1627cc0d 100644 --- a/common/models/tracked_utils.py +++ b/common/models/tracked_utils.py @@ -79,3 +79,40 @@ def get_deferred_set_fields(class_: type[Model]) -> Set[Field]: and hasattr(field.remote_field, "through") and field.remote_field.through._meta.auto_created } + + +def get_field_hashable_string(value): + """ + Given a field return a hashable string, containing the fields type and + value, ensuring uniqueness across types. + + For fields that are TrackedModels, delegate to their content_hash method. + For non TrackedModels return a combination of type and value. + """ + from common.models.trackedmodel import TrackedModel + + value_type = type(value) + if isinstance(value, TrackedModel): + # For TrackedModel fields use their content_hash, the type is still included as a debugging aid. + value_hash = value.content_hash().hexdigest() + return ( + f"{value_type.__module__}:{value_type.__name__}.content_hash={value_hash}" + ) + + return f"{value_type.__module__}:{value_type.__name__}={value}" + + +def get_field_hashable_strings(instance, fields): + """ + Given a model instance, return a dict of {field names: hashable string}, + + This calls `get_field_hashable_string` to generate strings unique to the type and value of the fields. + + :param instance: The model instance to generate hashes for. + :param fields: The fields to get use in the hash. + :return: Dictionary of {field_name: hash} + """ + return { + field.name: get_field_hashable_string(getattr(instance, field.name)) + for field in fields + } diff --git a/common/models/trackedmodel.py b/common/models/trackedmodel.py index 52d3f2c8e4..ca01b2027e 100644 --- a/common/models/trackedmodel.py +++ b/common/models/trackedmodel.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import lru_cache from hashlib import sha256 from json import dumps from typing import Any @@ -31,6 +32,7 @@ from common.models.managers import TrackedModelManager from common.models.tracked_qs import TrackedModelQuerySet from common.models.tracked_utils import get_deferred_set_fields +from common.models.tracked_utils import get_field_hashable_strings from common.models.tracked_utils import get_models_linked_to from common.models.tracked_utils import get_relations from common.models.tracked_utils import get_subrecord_relations @@ -108,8 +110,8 @@ class TrackedModel(PolymorphicModel): default, filters to the 'current' transaction. """ - business_rules: Iterable = () - indirect_business_rules: Iterable = () + business_rules: Sequence = () + indirect_business_rules: Sequence = () record_code: int """ @@ -661,6 +663,7 @@ def get_url_pattern_name_prefix(cls): prefix = cls._meta.verbose_name.replace(" ", "_") return prefix + @lru_cache(maxsize=None) def content_hash(self): """ Hash of the user editable content, used by business rule checks for @@ -668,14 +671,12 @@ def content_hash(self): :return: 32 character sha256 'digest', see hashlib.sha256. """ - content = { - field.name: str(getattr(self, field.name)) for field in self.copyable_fields - } - - # The json encoder ensures a somewhat regular format and everything - # passed to it must be hashable. - hashable = dumps(content).encode("utf-8") + # The json encoder ensures a somewhat regular format and ensures only simple data types can be passed in, + # in testing the speed of json encoding is around the same speed as stringifying. + hashable = dumps(get_field_hashable_strings(self, self.copyable_fields)).encode( + "utf-8", + ) sha = sha256() sha.update(hashable) - return sha.digest() + return sha diff --git a/common/models/transactions.py b/common/models/transactions.py index 48f50a0ac0..27abc61b07 100644 --- a/common/models/transactions.py +++ b/common/models/transactions.py @@ -95,7 +95,7 @@ def preorder_negative_transactions(self) -> None: order += 1 tx.order = order - type(self).objects.bulk_update(transactions, ["order"]) + self.model.objects.bulk_update(transactions, ["order"]) @atomic def move_to_end_of_partition(self, partition) -> None: diff --git a/common/models/utils.py b/common/models/utils.py index 3d655d5d2a..d2c841ef5a 100644 --- a/common/models/utils.py +++ b/common/models/utils.py @@ -65,11 +65,12 @@ def __init__(self, func): self.func = func def __str__(self): - return self.func() + return str(self.func()) @wrapt.decorator def lazy_string(wrapped, instance, *args, **kwargs): + """Decorator that will evaluate the wrapped function when stringified.""" return LazyString(wrapped) diff --git a/common/tests/test_business_rules.py b/common/tests/test_business_rules.py index ffeccbf311..446231709d 100644 --- a/common/tests/test_business_rules.py +++ b/common/tests/test_business_rules.py @@ -20,11 +20,6 @@ pytestmark = pytest.mark.django_db -class TestRule(BusinessRule): - __test__ = False - validate = MagicMock() - - def test_business_rule_violation_message(): model = MagicMock() violation = TestRule(model.transaction).violation(model) diff --git a/common/tests/test_models.py b/common/tests/test_models.py index 7fcbfa4f33..4598835ed2 100644 --- a/common/tests/test_models.py +++ b/common/tests/test_models.py @@ -8,6 +8,7 @@ import common.exceptions import workbaskets.models +from checks.tasks import check_workbasket_sync from common.exceptions import NoIdentifyingValuesGivenError from common.models import TrackedModel from common.models.transactions import Transaction @@ -28,7 +29,6 @@ from regulations.models import Group from regulations.models import Regulation from taric.models import Envelope -from workbaskets.tasks import check_workbasket_sync pytestmark = pytest.mark.django_db diff --git a/common/tests/util.py b/common/tests/util.py index c568ca781b..54225f550c 100644 --- a/common/tests/util.py +++ b/common/tests/util.py @@ -78,15 +78,11 @@ @contextlib.contextmanager -def raises_if(exception, expected): - try: - yield - except exception: - if not expected: - raise +def raises_if(exception, expected, *args, **kwargs): + if expected: + yield from pytest.raises(exception, *args, **kwargs) else: - if expected: - pytest.fail(f"Did not raise {exception}") + yield @contextlib.contextmanager @@ -96,11 +92,16 @@ def add_business_rules( """Attach BusinessRules to a TrackedModel.""" target = f"{'indirect_' if indirect else ''}business_rules" rules = (*rules, *getattr(model, target, [])) - with patch.object(model, target, new=rules): - yield + with patch.object(model, target, new=tuple(rules)): + yield model + + +class TestRule1(BusinessRule): + __test__ = False + validate = MagicMock() -class TestRule(BusinessRule): +class TestRule2(BusinessRule): __test__ = False validate = MagicMock() diff --git a/exporter/management/commands/dump_transactions.py b/exporter/management/commands/dump_transactions.py index 9fffcae9ea..b329481577 100644 --- a/exporter/management/commands/dump_transactions.py +++ b/exporter/management/commands/dump_transactions.py @@ -1,3 +1,5 @@ +import ast +import itertools import os import sys @@ -52,7 +54,7 @@ def add_arguments(self, parser): "with a comma-separated list of workbasket ids." ), nargs="*", - type=int, + type=ast.literal_eval, default=None, action="store", ) @@ -76,7 +78,7 @@ def add_arguments(self, parser): def handle(self, *args, **options): workbasket_ids = options.get("workbasket_ids") if workbasket_ids: - query = dict(id__in=workbasket_ids) + query = dict(id__in=itertools.chain.from_iterable(workbasket_ids)) else: query = dict(status=WorkflowStatus.APPROVED) diff --git a/footnotes/tests/test_views.py b/footnotes/tests/test_views.py index 3f776ee672..14b41e6c26 100644 --- a/footnotes/tests/test_views.py +++ b/footnotes/tests/test_views.py @@ -1,6 +1,7 @@ import pytest from django.core.exceptions import ValidationError +from checks.tasks import check_workbasket_sync from common.tests import factories from common.tests.util import assert_model_view_renders from common.tests.util import get_class_based_view_urls_matching_url @@ -13,7 +14,6 @@ from common.views import TrackedModelDetailMixin from footnotes.models import Footnote from footnotes.views import FootnoteList -from workbaskets.tasks import check_workbasket_sync pytestmark = pytest.mark.django_db diff --git a/pii-ner-exclude.txt b/pii-ner-exclude.txt index 10408aeb88..1e3fd37f4f 100644 --- a/pii-ner-exclude.txt +++ b/pii-ner-exclude.txt @@ -1150,3 +1150,68 @@ param kwargs: Enum sha256 hashlib.sha256 +" Generator of model +XXXX - TODO +Celery Workflow +XXXX TODO +is_transaction_check_complete(check_id +check_id +up_to_date_check.get +rate_limit="1 +" Generator of model +XXXX - TODO +Celery Workflow +XXXX TODO +is_transaction_check_complete(check_id +check_id +up_to_date_check.get +rate_limit="1 +mock.patch("celery.app.task +assert transaction.tracked_models.count +assert workflow.body.args[0 +assert initial_require_update.count +assert checks_require_update.get().pk +TrackedModelsCheck +TrackedModelChecks +the Celery Task +GenericRelation +Unlink a Celery Task UUID +SubFactory +TransactionFactory +Trait +assert instance.transaction.tracked_models.count +TrackedModelCheckFactory(factory.django +Sequence +Found existing TrackedModelsCheck % +check_models(model_pks +Split +trackedmodel_ptr_id +BusinessRules +finish_models_check +Remove +check_workbasket_models( +clear_cache +TrackedModelsCheckStatus(enum +TrackedModelsCheck(TimestampedMixin +WorkbasketCheck(TrackedModelsCheck +TrackedModelsCheckChunks +TaskModel +PolymorphicModel +WorkBasketOutputFormat Enum +TaskInfo +ByteFields +get_or_create +models.content_hash +SubTask Waiting +checks.models +Build +SIGUSR1 +AsyncResults +GroupResults +f"{tuple(res.args +GroupResult +self.stdout.flush +checker.apply_rule +TestRule +LinkedModelsBusinessRuleChecker.of(TestRule +assert len(checkers diff --git a/quotas/models.py b/quotas/models.py index 584cb5bb83..5448155e6c 100644 --- a/quotas/models.py +++ b/quotas/models.py @@ -125,6 +125,9 @@ class QuotaOrderNumberOrigin(TrackedModel, ValidityMixin): UpdateValidity, ) + def __str__(self): + return self.sid + def order_number_in_use(self, transaction): return self.order_number.in_use(transaction) diff --git a/settings/common.py b/settings/common.py index d530cc188b..5d0caedab2 100644 --- a/settings/common.py +++ b/settings/common.py @@ -331,8 +331,8 @@ AWS_S3_SIGNATURE_VERSION = "s3v4" AWS_S3_REGION_NAME = "eu-west-2" +# For info on celery settings see the docs at https://docs.celeryq.dev/en/stable/userguide/configuration.html # Pickle could be used as a serializer here, as this always runs in a DMZ - CELERY_BROKER_URL = os.environ.get("CELERY_BROKER_URL", CACHES["default"]["LOCATION"]) if VCAP_SERVICES.get("redis"): @@ -350,12 +350,29 @@ CELERY_TIMEZONE = TIME_ZONE CELERY_WORKER_POOL_RESTARTS = True # Restart worker if it dies -CELERY_BEAT_SCHEDULE = { - "sqlite_export": { - "task": "exporter.sqlite.tasks.export_and_upload_sqlite", - "schedule": timedelta(minutes=30), - }, -} +CELERY_RESULT_EXTENDED = True # Adds Task name, args, kwargs to results. + +# The following settings are usually useful for development, but not for production. +CELERY_TASK_ALWAYS_EAGER = is_truthy(os.environ.get("CELERY_TASK_ALWAYS_EAGER", "N")) +CELERY_TASK_EAGER_PROPAGATES = is_truthy( + os.environ.get("CELERY_TASK_EAGER_PROPAGATES", "N"), +) +CELERY_TASK_REMOTE_TRACEBACKS = is_truthy( + os.environ.get("CELERY_TASK_REMOTE_TRACEBACKS", "N"), +) + +CELERY_BEAT_SCHEDULE = {} +if False: + CELERY_BEAT_SCHEDULE = { + "sqlite_export": { + "task": "exporter.sqlite.tasks.export_and_upload_sqlite", + "schedule": timedelta(minutes=30), + }, + } + +RAISE_BUSINESS_RULE_FAILURES = is_truthy( + os.environ.get("RAISE_BUSINESS_RULE_FAILURES", "N"), +) SQLITE_EXCLUDED_APPS = [ "checks", diff --git a/workbaskets/management/commands/list_workbaskets.py b/workbaskets/management/commands/list_workbaskets.py index 7eecbe8722..631ecbee2a 100644 --- a/workbaskets/management/commands/list_workbaskets.py +++ b/workbaskets/management/commands/list_workbaskets.py @@ -56,7 +56,8 @@ def add_arguments(self, parser: CommandParser) -> None: parser.add_argument( "workbasket_ids", - help=("Comma-separated list of workbasket ids to filter to"), + nargs="*", + help="Comma-separated list of workbasket ids to filter to", type=ast.literal_eval, ) diff --git a/workbaskets/management/commands/run_checks.py b/workbaskets/management/commands/run_checks.py new file mode 100644 index 0000000000..8b7414f5fd --- /dev/null +++ b/workbaskets/management/commands/run_checks.py @@ -0,0 +1,208 @@ +import logging +import signal +from typing import Any +from typing import Dict +from typing import Optional + +from celery.result import AsyncResult +from celery.result import GroupResult +from django.core.management import BaseCommand +from django.core.management.base import CommandParser + +from checks.models import TrackedModelCheck +from workbaskets.models import WorkBasket + +logger = logging.getLogger(__name__) + + +CLEAR_TO_END_OF_LINE = "\x1b[K" + + +def revoke_task_and_children(task, depth=0): + """ + Revoke a task by task_id. + + Uses SIGUSR1, which invokes the SoftTimeLimitExceeded exception, this is + more friendly than plain terminate, which may kill other tasks in the + worker. + """ + if task.children: + for subtask in task.children: + yield from revoke_task_and_children(subtask, depth + 1) + + task.revoke(terminate=True, signal="SIGUSR1") + yield task, depth + + +class TaskControlMixin: + # Implementing classes can populate this with + IGNORE_TASK_PREFIXES = [] + + def get_readable_task_name(self, node): + """Optionally remove a prefix from the task name (used to remove the + module which is often repeated)""" + task_name = getattr(node, "name") or "" + for prefix in self.IGNORE_TASK_PREFIXES: + unprefixed = task_name.replace(prefix, "") + if unprefixed != task_name: + return unprefixed + + return task_name + + def revoke_task_and_children_and_display_result(self, task): + """Call revoke_task_and_children and display information on each revoked + task.""" + for revoked_task, depth in revoke_task_and_children(task): + if isinstance(task, AsyncResult): + self.stdout.write( + " " * depth + + f"{getattr(revoked_task, 'name', None) or '-'} [{revoked_task.id}] {revoked_task.status}", + ) + + def revoke_task_on_sigint(self, task): + """ + Connect a signal handler to attempt to revoke a task if the user presses + Ctrl+C. + + Due to the way tasks travel through Celery, not all tasks can be + revoked. + """ + + def sigint_handler(sig, frame): + """Revoke celery task with task_id.""" + self.stdout.write(f"Received SIGINT, revoking task {task.id} and children.") + self.revoke_task_and_children_and_display_result(task) + + raise SystemExit(1) + + signal.signal(signal.SIGINT, sigint_handler) + + def display_task(self, node, value, depth): + """Default task display.""" + # For now this only shows args, to avoid some noise - if this is more widely + # used that may be something that could be configured by the caller. + readable_task_name = self.get_readable_task_name( + node, + ) + self.stdout.write( + " " * depth * 2 + f"{readable_task_name} " + f"{tuple(node.args)}", + ) + + def iterate_ongoing_tasks(self, result, ignore_groupresult=True): + """ + Iterate over the ongoing tasks as they are received, track their. + + depth - which is useful for visual formatting. + + Yields: (node, value, depth) + """ + task_depths: Dict[str, int] = {} + task_depths[result.id] = 0 + + for parent_id, node in result.iterdeps(intermediate=True): + value = node.get() + + depth = task_depths.get(parent_id, -1) + if isinstance(node, GroupResult) and ignore_groupresult: + # GroupResult is ignored: store it so looking up depth works + # but do not increase the indent.. + task_depths[node.id] = depth + continue + else: + task_depths[node.id] = depth + 1 + + yield node, value, depth + + def display_ongoing_tasks(self, result): + for node, value, depth in self.iterate_ongoing_tasks(result): + self.display_task(node, value, depth) + + +class Command(TaskControlMixin, BaseCommand): + IGNORE_TASK_PREFIXES = [ + "checks.tasks.", + ] + + passed = 0 + failed = 0 + + help = ( + "Run all business rule checks against a WorkBasket's TrackedModels in Celery." + ) + + def add_arguments(self, parser: CommandParser) -> None: + parser.add_argument("WORKBASKET_PK", type=int) + parser.add_argument("--clear-cache", action="store_true", default=False) + parser.add_argument( + "--throw", + help="Allow failing celery tasks to throw exceptions [dev setting]", + action="store_true", + default=False, + ) + + def display_check_model_task(self, node, value, depth): + model_pk = node.args[1] + check_passed = value + readable_task_name = self.get_readable_task_name(node) + style = self.style.SUCCESS if check_passed else self.style.ERROR + self.stdout.write( + " " * depth * 2 + + f"{readable_task_name} " + + style( + f"{TrackedModelCheck.objects.filter(model=model_pk).last()}", + ), + ) + + def display_task(self, node, value, depth): + """Custom display for check_model tasks, acculate their passes / + fails.""" + task_name = getattr(node, "name", None) + if task_name == "checks.tasks.check_model": + self.display_check_model_task(node, value, depth) + if value: + self.passed += 1 + else: + self.failed += 1 + else: + super().display_task(node, value, depth) + + def handle(self, *args: Any, **options: Any) -> Optional[str]: + from checks.tasks import check_workbasket + + # Get the workbasket first + workbasket = WorkBasket.objects.get( + pk=int(options["WORKBASKET_PK"]), + ) + clear_cache = options["clear_cache"] + rule_names = None + throw = options["throw"] + + # Temporarily display a message while waiting for celery, this will only have time to show up + # if celery isn't working (easy enough on a dev machine), or is busy. + self.stdout.write("Connecting to celery... ⌛", ending="") + self.stdout._out.flush() # self.stdout.flush() doesn't result in any output - should report as a bug to django. + result = check_workbasket.apply_async( + args=( + workbasket.pk, + None, + ), + kwargs={ + "clear_cache": clear_cache, + "rules": rule_names, + }, + throw=throw, + ) + result.wait() + self.stdout.write(f"\r{CLEAR_TO_END_OF_LINE}") + + # Attach a handler to revoke the task and its subtasks if the user presses Ctrl+C + self.revoke_task_on_sigint(result) + + self.display_ongoing_tasks(result) + self.stdout.write() + + style = self.style.ERROR if self.failed else self.style.SUCCESS + self.stdout.write(style(f"Failed: {self.failed}")) + self.stdout.write(style(f"Passed: {self.passed}")) + self.stdout.write() + return 1 if self.failed else 0 diff --git a/workbaskets/management/commands/sync_run_checks.py b/workbaskets/management/commands/sync_run_checks.py index a925c10971..8b16d87f99 100644 --- a/workbaskets/management/commands/sync_run_checks.py +++ b/workbaskets/management/commands/sync_run_checks.py @@ -5,9 +5,9 @@ from django.core.management import BaseCommand from django.core.management.base import CommandParser +from checks.tasks import check_workbasket_sync from workbaskets.management.util import WorkBasketCommandMixin from workbaskets.models import WorkBasket -from workbaskets.tasks import check_workbasket_sync logger = logging.getLogger(__name__) diff --git a/workbaskets/management/util.py b/workbaskets/management/util.py index 3c6be99c34..59c89f64f8 100644 --- a/workbaskets/management/util.py +++ b/workbaskets/management/util.py @@ -22,8 +22,14 @@ def _output_workbasket_readable( self.stdout.write(f"{spaces}reason: {first_line_of(workbasket.reason)}") self.stdout.write(f"{spaces}status: {workbasket.status}") if show_transaction_info: + transactions = workbasket.transactions + first_pk = ( + workbasket.transactions.first().pk if transactions.count() else "-" + ) + last_pk = workbasket.transactions.last().pk if transactions.count() else "-" + self.stdout.write( - f"{spaces}transactions: {workbasket.transactions.first().pk} - {workbasket.transactions.last().pk}", + f"{spaces}transactions: {first_pk} - {last_pk} [{transactions.count()}]", ) def _output_workbasket_compact(self, workbasket, show_transaction_info, **kwargs): @@ -32,8 +38,13 @@ def _output_workbasket_compact(self, workbasket, show_transaction_info, **kwargs ending="" if show_transaction_info else "\n", ) if show_transaction_info: + transactions = workbasket.transactions + first_pk = ( + workbasket.transactions.first().pk if transactions.count() else "-" + ) + last_pk = workbasket.transactions.last().pk if transactions.count() else "-" self.stdout.write( - f", {workbasket.transactions.first().pk} - {workbasket.transactions.last().pk}", + f", {first_pk} - {last_pk} [{transactions.count()}]", ) def output_workbasket( diff --git a/workbaskets/models.py b/workbaskets/models.py index 53f1e6d0ef..21392f489b 100644 --- a/workbaskets/models.py +++ b/workbaskets/models.py @@ -7,7 +7,6 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.core.exceptions import ValidationError from django.db import models from django.db.models import QuerySet from django.db.models import Subquery @@ -15,7 +14,7 @@ from django_fsm import transition from checks.models import TrackedModelCheck -from checks.models import TransactionCheck +from common.models import ModelCeleryTask from common.models.mixins import TimestampedMixin from common.models.tracked_qs import TrackedModelQuerySet from common.models.trackedmodel import TrackedModel @@ -328,10 +327,11 @@ def submit_for_approval(self): if not self.transactions.exists(): return - if self.unchecked_or_errored_transactions.exists(): - raise ValidationError( - "Transactions have not yet been fully checked or contain errors", - ) + # TODO + # if self.unchecked_or_errored_transactions.exists(): + # raise ValidationError( + # "Transactions have not yet been fully checked or contain errors", + # ) @transition( field=status, @@ -486,23 +486,8 @@ def tracked_model_check_errors(self): ) def delete_checks(self): - """Delete all TrackedModelCheck and TransactionCheck instances related - to the WorkBasket.""" - TrackedModelCheck.objects.filter( - transaction_check__transaction__workbasket=self, - ).delete() - TransactionCheck.objects.filter( - transaction__workbasket=self, - ).delete() - - @property - def unchecked_or_errored_transactions(self): - return self.transactions.exclude( - pk__in=TransactionCheck.objects.requires_update(False) - .filter( - completed=True, - successful=True, - transaction__workbasket=self, - ) - .values("transaction__pk"), - ) + """Delete all TrackedModelCheck and ModelCeleryTask instances related to + the WorkBasket.""" + checks = TrackedModelCheck.objects.filter(model__transaction__workbasket=self) + ModelCeleryTask.objects.filter(object__in=checks).delete() + checks.delete() diff --git a/workbaskets/tasks.py b/workbaskets/tasks.py index cffd3464bb..f66b8ae67c 100644 --- a/workbaskets/tasks.py +++ b/workbaskets/tasks.py @@ -1,11 +1,9 @@ -from celery import group +"""Also see checks.tasks, which contains check_workbasket task which checks +business rules.""" from celery import shared_task from celery.utils.log import get_task_logger from django.db.transaction import atomic -from checks.tasks import check_transaction -from checks.tasks import check_transaction_sync -from common.celery import app from workbaskets.models import WorkBasket # Celery logger adds the task id and status and outputs via the worker. @@ -26,34 +24,3 @@ def transition(instance_id: int, state: str, *args): getattr(instance, state)(*args) instance.save() logger.info("Transitioned workbasket %s to state %s", instance_id, instance.status) - - -@app.task(bind=True) -def check_workbasket(self, workbasket_id: int): - """Run and record transaction checks for the passed workbasket ID, - asynchronously.""" - - workbasket: WorkBasket = WorkBasket.objects.get(pk=workbasket_id) - transactions = workbasket.transactions.values_list("pk", flat=True) - - logger.debug("Setup task to check workbasket %s", workbasket_id) - return self.replace(group(check_transaction.si(id) for id in transactions)) - - -def check_workbasket_sync(workbasket: WorkBasket): - """ - Run and record transaction checks for the passed workbasket ID, - synchronously. - - This method will run all of the checks one after the other and won't return - until they are complete. This is useful for testing and debugging. - """ - transactions = workbasket.transactions.all() - - logger.debug( - "Start synchronous check of workbasket %s with % transactions", - workbasket.pk, - transactions.count(), - ) - for transaction in transactions: - check_transaction_sync(transaction) diff --git a/workbaskets/tests/test_models.py b/workbaskets/tests/test_models.py index acaf9aa8d2..9d014fdc70 100644 --- a/workbaskets/tests/test_models.py +++ b/workbaskets/tests/test_models.py @@ -12,7 +12,7 @@ from common.tests.factories import SeedFileTransactionFactory from common.tests.factories import TransactionFactory from common.tests.factories import WorkBasketFactory -from common.tests.util import TestRule +from common.tests.util import TestRule1 from common.tests.util import add_business_rules from common.tests.util import assert_transaction_order from common.validators import UpdateType @@ -330,10 +330,10 @@ def test_workbasket_clean_does_not_run_business_rules(): moderate sized workbasket will time out a web request.""" model = factories.TestModel1Factory.create() - with add_business_rules(type(model), TestRule): + with add_business_rules(type(model), TestRule1): model.transaction.workbasket.full_clean() - assert TestRule.validate.not_called() + assert TestRule1.validate.not_called() def test_current_transaction_returns_last_approved_transaction( diff --git a/workbaskets/tests/util.py b/workbaskets/tests/util.py index 4cf10a1792..32fd373316 100644 --- a/workbaskets/tests/util.py +++ b/workbaskets/tests/util.py @@ -1,5 +1,5 @@ +from checks.tasks import check_workbasket_sync from workbaskets.models import WorkBasket -from workbaskets.tasks import check_workbasket_sync def assert_workbasket_valid(workbasket: WorkBasket):