Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use current() instead of approved_up_to_transaction() in codebase #1089

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
4 changes: 1 addition & 3 deletions additional_codes/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,8 @@ def clean(self):
def save(self, commit=True):
instance = super().save(commit=False)

tx = WorkBasket.get_current_transaction(self.request)

highest_sid = (
models.AdditionalCode.objects.approved_up_to_transaction(tx).aggregate(
models.AdditionalCode.objects.current().aggregate(
Max("sid"),
)["sid__max"]
) or 0
Expand Down
9 changes: 3 additions & 6 deletions additional_codes/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ class AdditionalCodeViewSet(viewsets.ReadOnlyModelViewSet):
filter_backends = [AdditionalCodeFilterBackend]

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return (
AdditionalCode.objects.approved_up_to_transaction(tx)
AdditionalCode.objects.current()
.select_related("type")
.prefetch_related("descriptions")
)
Expand All @@ -65,8 +64,7 @@ class AdditionalCodeMixin:
model: Type[TrackedModel] = AdditionalCode

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return AdditionalCode.objects.approved_up_to_transaction(tx).select_related(
return AdditionalCode.objects.current().select_related(
"type",
)

Expand Down Expand Up @@ -208,8 +206,7 @@ class AdditionalCodeDescriptionMixin:
model: Type[TrackedModel] = AdditionalCodeDescription

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return AdditionalCodeDescription.objects.approved_up_to_transaction(tx)
return AdditionalCodeDescription.objects.current()


class AdditionalCodeDescriptionCreate(
Expand Down
8 changes: 3 additions & 5 deletions certificates/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def __init__(self, *args, **kwargs):

def filter_certificates_for_sid(self, sid):
certificate_type = self.cleaned_data["certificate_type"]
tx = WorkBasket.get_current_transaction(self.request)
return models.Certificate.objects.approved_up_to_transaction(tx).filter(
return models.Certificate.objects.current().filter(
sid=sid,
certificate_type=certificate_type,
)
Expand All @@ -64,14 +63,13 @@ def next_sid(self, instance):
form's save() method (with its commit param set either to True or
False).
"""
current_transaction = WorkBasket.get_current_transaction(self.request)
# Filter certificate by type and find the highest sid, using regex to
# ignore legacy, non-numeric identifiers
return get_next_id(
models.Certificate.objects.filter(
models.Certificate.objects.current().filter(
sid__regex=r"^[0-9]*$",
certificate_type__sid=instance.certificate_type.sid,
).approved_up_to_transaction(current_transaction),
),
instance._meta.get_field("sid"),
max_len=3,
)
Expand Down
9 changes: 3 additions & 6 deletions certificates/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ class CertificatesViewSet(viewsets.ReadOnlyModelViewSet):
permission_classes = [permissions.IsAuthenticated]

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return (
models.Certificate.objects.approved_up_to_transaction(tx)
models.Certificate.objects.current()
.select_related("certificate_type")
.prefetch_related("descriptions")
)
Expand All @@ -60,8 +59,7 @@ class CertificateMixin:
model: Type[TrackedModel] = models.Certificate

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return models.Certificate.objects.approved_up_to_transaction(tx).select_related(
return models.Certificate.objects.current().select_related(
"certificate_type",
)

Expand Down Expand Up @@ -237,8 +235,7 @@ class CertificateDescriptionMixin:
model: Type[TrackedModel] = models.CertificateDescription

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return models.CertificateDescription.objects.approved_up_to_transaction(tx)
return models.CertificateDescription.objects.current()


class CertificateCreateDescriptionMixin:
Expand Down
32 changes: 11 additions & 21 deletions commodities/business_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@ def __init__(self, transaction=None):
self.logger = logging.getLogger(type(self).__name__)

def parent_spans_child(self, parent, child) -> bool:
parent_validity = parent.indented_goods_nomenclature.version_at(
self.transaction,
).valid_between
child_validity = child.indented_goods_nomenclature.version_at(
self.transaction,
).valid_between
parent_validity = parent.indented_goods_nomenclature.version_at().valid_between
child_validity = child.indented_goods_nomenclature.version_at().valid_between
return validity_range_contains_range(parent_validity, child_validity)

def parents_span_childs_future(self, parents, child):
Expand All @@ -59,17 +55,13 @@ def parents_span_childs_future(self, parents, child):
parents_validity = []
for parent in parents:
parents_validity.append(
parent.indented_goods_nomenclature.version_at(
self.transaction,
).valid_between,
parent.indented_goods_nomenclature.version_at().valid_between,
)

# sort by start date so any gaps will be obvious
parents_validity.sort(key=lambda daterange: daterange.lower)

child_validity = child.indented_goods_nomenclature.version_at(
self.transaction,
).valid_between
child_validity = child.indented_goods_nomenclature.version_at().valid_between

if (
not child_validity.upper_inf
Expand Down Expand Up @@ -108,7 +100,7 @@ def validate(self, indent):
from commodities.models.dc import get_chapter_collection

try:
good = indent.indented_goods_nomenclature.version_at(self.transaction)
good = indent.indented_goods_nomenclature.version_at()
except TrackedModel.DoesNotExist:
self.logger.warning(
"Goods nomenclature %s no longer exists at transaction %s "
Expand Down Expand Up @@ -166,10 +158,10 @@ def validate(self, good):

if not (
good.code.is_chapter
or GoodsNomenclatureOrigin.objects.filter(
or GoodsNomenclatureOrigin.objects.current()
.filter(
new_goods_nomenclature__sid=good.sid,
)
.approved_up_to_transaction(good.transaction)
.exists()
):
raise self.violation(
Expand Down Expand Up @@ -252,9 +244,9 @@ class NIG11(ValidityStartDateRules):
def get_objects(self, good):
GoodsNomenclatureIndent = good.indents.model

return GoodsNomenclatureIndent.objects.filter(
return GoodsNomenclatureIndent.objects.current().filter(
indented_goods_nomenclature__sid=good.sid,
).approved_up_to_transaction(self.transaction)
)


class NIG12(DescriptionsRules):
Expand Down Expand Up @@ -305,7 +297,7 @@ def validate(self, association):
goods_nomenclature__sid=association.goods_nomenclature.sid,
valid_between__overlap=association.valid_between,
)
.approved_up_to_transaction(association.transaction)
.current()
.exclude(
id=association.pk,
)
Expand Down Expand Up @@ -351,9 +343,7 @@ def has_violation(self, good):
goods_nomenclature__sid=good.sid,
additional_code__isnull=False,
)
.approved_up_to_transaction(
self.transaction,
)
.current()
.exists()
)

Expand Down
4 changes: 1 addition & 3 deletions commodities/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def init_fields(self):
self.fields[
"end_date"
].help_text = "Leave empty if the footnote is needed for an unlimited time"
self.fields[
"associated_footnote"
].queryset = Footnote.objects.approved_up_to_transaction(self.tx).filter(
self.fields["associated_footnote"].queryset = Footnote.objects.current().filter(
footnote_type__application_code__in=[1, 2],
)
self.fields[
Expand Down
11 changes: 4 additions & 7 deletions commodities/models/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def get_dependent_measures(
measure_qs = Measure.objects.filter(goods_sid_query)

if self.moment.clock_type.is_transaction_clock:
measure_qs = measure_qs.approved_up_to_transaction(self.moment.transaction)
measure_qs = measure_qs.current()
else:
measure_qs = measure_qs.latest_approved()

Expand Down Expand Up @@ -823,7 +823,7 @@ def get_snapshot(
date=snapshot_date,
)

commodities = self._get_snapshot_commodities(transaction, snapshot_date)
commodities = self._get_snapshot_commodities(snapshot_date)

return CommodityTreeSnapshot(
moment=moment,
Expand All @@ -832,7 +832,6 @@ def get_snapshot(

def _get_snapshot_commodities(
self,
transaction: Transaction,
snapshot_date: date,
) -> List[Commodity]:
"""
Expand All @@ -853,12 +852,10 @@ def _get_snapshot_commodities(
that match the latest_version goods.
"""
item_ids = {c.item_id for c in self.commodities if c.obj}
goods = GoodsNomenclature.objects.approved_up_to_transaction(
transaction,
).filter(
goods = GoodsNomenclature.objects.filter(
item_id__in=item_ids,
valid_between__contains=snapshot_date,
)
).current()

latest_versions = get_latest_versions(goods)
pks = {good.pk for good in latest_versions}
Expand Down
6 changes: 3 additions & 3 deletions commodities/models/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def get_url(self):
return reverse("commodity-ui-detail", kwargs={"sid": self.sid})

def get_dependent_measures(self, transaction=None):
return self.measures.model.objects.filter(
return self.measures.model.objects.current().filter(
goods_nomenclature__sid=self.sid,
).approved_up_to_transaction(transaction)
)

@property
def is_taric_code(self) -> bool:
Expand Down Expand Up @@ -200,7 +200,7 @@ def get_good_indents(
) -> QuerySet:
"""Return the related goods indents based on approval status."""
good = self.indented_goods_nomenclature
return good.indents.approved_up_to_transaction(
return good.indents.current(
as_of_transaction or self.transaction,
)

Expand Down
9 changes: 2 additions & 7 deletions commodities/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_commodity_list_queryset():
good_1 = factories.SimpleGoodsNomenclatureFactory.create(item_id="1010000000")
good_2 = factories.SimpleGoodsNomenclatureFactory.create(item_id="1000000000")
tx = Transaction.objects.last()
commodity_count = GoodsNomenclature.objects.approved_up_to_transaction(tx).count()
commodity_count = GoodsNomenclature.objects.current().count()
with override_current_transaction(tx):
qs = view.get_queryset()

Expand Down Expand Up @@ -522,12 +522,7 @@ def test_commodity_footnote_update_success(valid_user_client, date_ranges):
"end_date": "",
}
response = valid_user_client.post(url, data)
tx = Transaction.objects.last()
updated_association = (
FootnoteAssociationGoodsNomenclature.objects.approved_up_to_transaction(
tx,
).first()
)
updated_association = FootnoteAssociationGoodsNomenclature.objects.current().first()
assert response.status_code == 302
assert response.url == updated_association.get_url("confirm-update")

Expand Down
9 changes: 2 additions & 7 deletions commodities/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def get_queryset(self):
"""
tx = WorkBasket.get_current_transaction(self.request)
return (
GoodsNomenclature.objects.approved_up_to_transaction(
tx,
)
GoodsNomenclature.objects.current()
.prefetch_related("descriptions")
.as_at_and_beyond(date.today())
.filter(suffix=80)
Expand All @@ -69,10 +67,7 @@ class FootnoteAssociationMixin:
model = FootnoteAssociationGoodsNomenclature

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return FootnoteAssociationGoodsNomenclature.objects.approved_up_to_transaction(
tx,
)
return self.model.objects.current()


class CommodityList(CommodityMixin, WithPaginationListView):
Expand Down
12 changes: 6 additions & 6 deletions common/business_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_linked_models(
related_instances = [getattr(model, field.name)]
for instance in related_instances:
try:
yield instance.version_at(transaction)
yield instance.version_at()
except TrackedModel.DoesNotExist:
# `related_instances` will contain all instances, even
# deleted ones, and `version_at` will return
Expand Down Expand Up @@ -278,8 +278,8 @@ def validate(self, model):

if (
type(model)
.objects.filter(**query)
.approved_up_to_transaction(self.transaction)
.objects.current()
.filter(**query)
.exclude(version_group=model.version_group)
.exists()
):
Expand All @@ -305,8 +305,8 @@ def validate(self, model):
query["valid_between__overlap"] = model.valid_between

if (
model.__class__.objects.filter(**query)
.approved_up_to_transaction(self.transaction)
model.__class__.objects.current()
.filter(**query)
.exclude(version_group=model.version_group)
.exists()
):
Expand Down Expand Up @@ -573,7 +573,7 @@ def validate(self, exclusion):
Membership = geo_group._meta.get_field("members").related_model

if (
not Membership.objects.approved_up_to_transaction(self.transaction)
not Membership.objects.current()
.filter(
geo_group__sid=geo_group.sid,
member__sid=exclusion.excluded_geographical_area.sid,
Expand Down
14 changes: 9 additions & 5 deletions common/models/tracked_qs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def latest_approved(self) -> TrackedModelQuerySet:
update_type=UpdateType.DELETE,
)

def current(self) -> TrackedModelQuerySet:
def current(self, transaction=None) -> TrackedModelQuerySet:
Copy link
Collaborator

@paulpepper-trade paulpepper-trade Nov 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introducing a transaction parameter changes the meaning / intent of this function. I think retaining and using approved_up_to_transaction() is a clearer way of filtering tracked models to up to a specific transaction that is not the current one.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... TrackedModelQuerySet.current() is a convenience and short-cut to calling TrackedModelQuerySet.approved_up_to_transaction() for specific situations.

"""
Returns a queryset of approved versions of the model up to the globally
defined current transaction (see ``common.models.utils`` for details of
Expand All @@ -64,12 +64,16 @@ def current(self) -> TrackedModelQuerySet:
(see ``set_current_transaction()`` and ``override_current_transaction()``
in ``common.models.utils``).
"""
return self.approved_up_to_transaction(
LazyTransaction(get_value=get_current_transaction),
)
if transaction:
return self.approved_up_to_transaction(transaction)
else:
return self.approved_up_to_transaction(
LazyTransaction(get_value=get_current_transaction),
)

def approved_up_to_transaction(self, transaction=None) -> TrackedModelQuerySet:
"""Get the approved versions of the model being queried, unless there
"""This function is called using the current() function instead of directly calling it on model queries.
Get the approved versions of the model being queried, unless there
exists a version of the model in a draft state within a transaction
preceding (and including) the given transaction in the workbasket of the
given transaction."""
Expand Down
Loading
Loading