Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use current() instead of approved_up_to_transaction() in codebase #1089

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions commodities/models/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,10 +852,10 @@ def _get_snapshot_commodities(
that match the latest_version goods.
"""
item_ids = {c.item_id for c in self.commodities if c.obj}
goods = GoodsNomenclature.objects.current().filter(
goods = GoodsNomenclature.objects.filter(
item_id__in=item_ids,
valid_between__contains=snapshot_date,
)
).current()

latest_versions = get_latest_versions(goods)
pks = {good.pk for good in latest_versions}
Expand Down
11 changes: 7 additions & 4 deletions common/models/tracked_qs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def latest_approved(self) -> TrackedModelQuerySet:
update_type=UpdateType.DELETE,
)

def current(self) -> TrackedModelQuerySet:
def current(self, transaction=None) -> TrackedModelQuerySet:
Copy link
Collaborator

@paulpepper-trade paulpepper-trade Nov 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introducing a transaction parameter changes the meaning / intent of this function. I think retaining and using approved_up_to_transaction() is a clearer way of filtering tracked models to up to a specific transaction that is not the current one.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... TrackedModelQuerySet.current() is a convenience and short-cut to calling TrackedModelQuerySet.approved_up_to_transaction() for specific situations.

"""
Returns a queryset of approved versions of the model up to the globally
defined current transaction (see ``common.models.utils`` for details of
Expand All @@ -64,9 +64,12 @@ def current(self) -> TrackedModelQuerySet:
(see ``set_current_transaction()`` and ``override_current_transaction()``
in ``common.models.utils``).
"""
return self.approved_up_to_transaction(
LazyTransaction(get_value=get_current_transaction),
)
if transaction:
return self.approved_up_to_transaction(transaction)
else:
return self.approved_up_to_transaction(
LazyTransaction(get_value=get_current_transaction),
)

def approved_up_to_transaction(self, transaction=None) -> TrackedModelQuerySet:
"""This function is called using the current() function instead of directly calling it on model queries.
Expand Down
2 changes: 1 addition & 1 deletion common/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def not_current(self, asof_transaction=None) -> QuerySet:
:param transaction Transaction: The transaction to limit versions to.
:rtype QuerySet:
"""
current = self.current()
current = self.current(transaction=asof_transaction)

return self.difference(current)

Expand Down
3 changes: 0 additions & 3 deletions footnotes/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class FootnoteViewSet(viewsets.ReadOnlyModelViewSet):
]

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return (
models.Footnote.objects.current()
nboyse marked this conversation as resolved.
Show resolved Hide resolved
.select_related("footnote_type")
Expand All @@ -67,7 +66,6 @@ class FootnoteMixin:
model: Type[TrackedModel] = models.Footnote

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return models.Footnote.objects.current().select_related(
nboyse marked this conversation as resolved.
Show resolved Hide resolved
"footnote_type",
)
Expand All @@ -77,7 +75,6 @@ class FootnoteDescriptionMixin:
model: Type[TrackedModel] = models.FootnoteDescription

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return models.FootnoteDescription.objects.current()
nboyse marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
1 change: 0 additions & 1 deletion geo_areas/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ def clean(self):
)

if membership and action == GeoMembershipAction.DELETE:
tx = WorkBasket.get_current_transaction(self.request)
if membership.member_used_in_measure_exclusion():
nboyse marked this conversation as resolved.
Show resolved Hide resolved
self.add_error(
"membership",
Expand Down
2 changes: 1 addition & 1 deletion geo_areas/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def test_geo_area_edit_create_view(
area_code=AreaCode.REGION,
area_id="TR",
valid_between=date_ranges.no_end,
transaction=session_workbasket.new_transaction(),
transaction__workbasket=session_workbasket,
)

data_changes = {**date_post_data("end_date", date_ranges.normal.upper)}
Expand Down
2 changes: 0 additions & 2 deletions geo_areas/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,13 @@ class GeoAreaMixin:
model: Type[TrackedModel] = GeographicalArea

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return GeographicalArea.objects.current()
nboyse marked this conversation as resolved.
Show resolved Hide resolved


class GeoAreaDescriptionMixin:
model: Type[TrackedModel] = GeographicalAreaDescription

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return GeographicalAreaDescription.objects.current()
nboyse marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
4 changes: 2 additions & 2 deletions measures/business_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ def validate(self, measure):

goods = (
type(measure.goods_nomenclature)
.objects.current()
.filter(
.objects.filter(
sid=measure.goods_nomenclature.sid,
valid_between__overlap=measure.effective_valid_between,
)
.current()
)

explosion_level = measure.measure_type.measure_explosion_level
Expand Down
2 changes: 0 additions & 2 deletions measures/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,6 @@ def __init__(self, *args, **kwargs):
self.request = kwargs.pop("request", None)
super().__init__(*args, **kwargs)

tx = WorkBasket.get_current_transaction(self.request)

self.initial["duty_sentence"] = self.instance.duty_sentence
self.request.session[
f"instance_duty_sentence_{self.instance.sid}"
Expand Down
4 changes: 2 additions & 2 deletions measures/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,14 +656,14 @@ def auto_value_fields(cls):

def has_components(self, transaction):
return (
MeasureComponent.objects.current()
MeasureComponent.objects.current(transaction=transaction)
.filter(component_measure__sid=self.sid)
.exists()
)

def has_condition_components(self, transaction):
return (
MeasureConditionComponent.objects.current()
MeasureConditionComponent.objects.current(transaction=transaction)
.filter(condition__dependent_measure__sid=self.sid)
.exists()
)
Expand Down
4 changes: 3 additions & 1 deletion measures/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def duty_sentence(

# Components with the greatest transaction_id that is less than
# or equal to component_parent's transaction_id, are considered 'current'.
component_qs = component_parent.components.current()
component_qs = component_parent.components.current(
transaction=component_parent.transaction
)
if not component_qs:
return ""
latest_transaction_id = component_qs.aggregate(
Expand Down
12 changes: 10 additions & 2 deletions measures/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,19 @@ class Meta:
measure_type_description = factory.SelfAttribute("measure.measure_type.description")
duty_sentence = factory.sequence(lambda n: f"{n}.00%")
origin_description = factory.LazyAttribute(
lambda m: m.measure.geographical_area.descriptions.current().last().description,
lambda m: m.measure.geographical_area.descriptions.current(
transaction=m.measure.geographical_area.transaction
)
.last()
.description,
)
excluded_origin_descriptions = factory.LazyAttribute(
lambda m: random.choice(MeasureSheetRow.separators).join(
e.excluded_geographical_area.descriptions.current().last().description
e.excluded_geographical_area.descriptions.current(
transaction=e.excluded_geographical_area
)
.last()
.description
for e in m.measure.exclusions.all()
),
)
Expand Down
5 changes: 0 additions & 5 deletions measures/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,6 @@ def test_measure_update_duty_sentence(
assert response.status_code == 302

if update_data:
tx = Transaction.objects.last()
measure = Measure.objects.current().get(
nboyse marked this conversation as resolved.
Show resolved Hide resolved
sid=measure_form.instance.sid,
)
Expand Down Expand Up @@ -700,7 +699,6 @@ def test_measure_update_create_conditions(
assert response.status_code == 302
assert response.url == reverse("measure-ui-confirm-update", args=(measure.sid,))

tx = Transaction.objects.last()
updated_measure = Measure.objects.current().get(
nboyse marked this conversation as resolved.
Show resolved Hide resolved
sid=measure.sid,
)
Expand Down Expand Up @@ -756,7 +754,6 @@ def test_measure_update_edit_conditions(
client.force_login(valid_user)
client.post(url, data=measure_edit_conditions_data)
transaction_count = Transaction.objects.count()
tx = Transaction.objects.last()
measure_with_condition = Measure.objects.current().get(
nboyse marked this conversation as resolved.
Show resolved Hide resolved
sid=measure.sid,
)
Expand All @@ -771,7 +768,6 @@ def test_measure_update_edit_conditions(
f"{MEASURE_CONDITIONS_FORMSET_PREFIX}-0-applicable_duty"
] = "10 GBP / 100 kg"
client.post(url, data=measure_edit_conditions_data)
tx = Transaction.objects.last()
updated_measure = Measure.objects.current().get(
nboyse marked this conversation as resolved.
Show resolved Hide resolved
sid=measure.sid,
)
Expand Down Expand Up @@ -877,7 +873,6 @@ def test_measure_update_remove_conditions(
# We expect one transaction for the measure update and condition deletion
assert Transaction.objects.count() == transaction_count + 1

tx = Transaction.objects.last()
updated_measure = Measure.objects.current().get(
nboyse marked this conversation as resolved.
Show resolved Hide resolved
sid=measure.sid,
)
Expand Down
4 changes: 0 additions & 4 deletions measures/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class MeasureTypeViewSet(viewsets.ReadOnlyModelViewSet):
filter_backends = [MeasureTypeFilterBackend]

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)
return MeasureType.objects.current().order_by(
"description",
)
Expand All @@ -85,7 +84,6 @@ class MeasureMixin:
model: Type[TrackedModel] = Measure

def get_queryset(self):
tx = WorkBasket.get_current_transaction(self.request)

return Measure.objects.current()

Expand Down Expand Up @@ -1061,7 +1059,6 @@ def get_form_kwargs(self, step):

def get_form(self, step=None, data=None, files=None):
form = super().get_form(step, data, files)
tx = WorkBasket.get_current_transaction(self.request)
forms = [form]
if hasattr(form, "forms"):
forms = form.forms
Expand Down Expand Up @@ -1103,7 +1100,6 @@ def get_form_kwargs(self):

def get_form(self, form_class=None):
form = super().get_form(form_class=form_class)
tx = WorkBasket.get_current_transaction(self.request)

if hasattr(form, "field"):
for field in form.fields.values():
Expand Down
25 changes: 18 additions & 7 deletions quotas/business_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ON2(BusinessRule):
def validate(self, order_number):
if (
type(order_number)
.objects.current()
.objects.current(transaction=order_number.transaction)
.filter(
order_number=order_number.order_number,
valid_between__overlap=order_number.valid_between,
Expand All @@ -70,7 +70,12 @@ def validate(self, order_number):
order_number_versions = order_number.get_versions()
origin_exists = False
for order_number_version in order_number_versions:
if order_number_version.origins.current().count() > 0:
if (
order_number_version.origins.current(
transaction=order_number.transaction
).count()
> 0
):
origin_exists = True
break

Expand All @@ -85,7 +90,7 @@ class ON5(BusinessRule):
def validate(self, origin):
if (
type(origin)
.objects.current()
.objects.current(transaction=origin.transaction)
.filter(
order_number__sid=origin.order_number.sid,
geographical_area__sid=origin.geographical_area.sid,
Expand Down Expand Up @@ -179,7 +184,9 @@ def validate(self, order_number_origin):
check that there are no measures linked to the origin .
"""

measures = measures_models.Measure.objects.current()
measures = measures_models.Measure.objects.current(
transaction=order_number_origin.transaction
)

if not measures.exists():
return
Expand Down Expand Up @@ -317,7 +324,7 @@ class OverlappingQuotaDefinition(BusinessRule):
def validate(self, quota_definition):
potential_quota_definition_matches = (
type(quota_definition)
.objects.current()
.objects.current(transaction=quota_definition.transaction)
.filter(
order_number=quota_definition.order_number,
valid_between__overlap=quota_definition.valid_between,
Expand Down Expand Up @@ -347,7 +354,9 @@ def validate(self, quota_definition):
if quota_definition.valid_between.lower < datetime.date.today():
return True

if quota_definition.sub_quota_associations.current().exists():
if quota_definition.sub_quota_associations.current(
transaction=self.transaction
).exists():
return True

if quota_definition.volume != quota_definition.initial_volume:
Expand Down Expand Up @@ -458,7 +467,9 @@ class QA6(BusinessRule):

def validate(self, association):
if (
association.main_quota.sub_quota_associations.current()
association.main_quota.sub_quota_associations.current(
transaction=association.transaction
)
.values(
"sub_quota_relation_type",
)
Expand Down
Loading