Skip to content

Commit

Permalink
Fixed #34838 -- Corrected output_field of resolved columns for Genera…
Browse files Browse the repository at this point in the history
…tedFields.

Thanks Simon Charette for the implementation idea.
  • Loading branch information
pauloxnet authored and felixxm committed Sep 14, 2023
1 parent 969ecb8 commit 68d769e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
12 changes: 12 additions & 0 deletions django/db/models/fields/generated.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from django.core import checks
from django.db import connections, router
from django.db.models.sql import Query
from django.utils.functional import cached_property

from . import NOT_PROVIDED, Field

Expand Down Expand Up @@ -32,6 +33,17 @@ def __init__(self, *, expression, db_persist=None, output_field=None, **kwargs):
self.db_persist = db_persist
super().__init__(**kwargs)

@cached_property
def cached_col(self):
from django.db.models.expressions import Col

return Col(self.model._meta.db_table, self, self.output_field)

def get_col(self, alias, output_field=None):
if alias != self.model._meta.db_table and output_field is None:
output_field = self.output_field
return super().get_col(alias, output_field)

def contribute_to_class(self, *args, **kwargs):
super().contribute_to_class(*args, **kwargs)

Expand Down
36 changes: 35 additions & 1 deletion tests/model_fields/test_generatedfield.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.core.exceptions import FieldError
from django.db import IntegrityError, connection
from django.db.models import F, GeneratedField, IntegerField
from django.db.models import F, FloatField, GeneratedField, IntegerField, Model
from django.db.models.functions import Lower
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature

Expand Down Expand Up @@ -49,6 +49,40 @@ def test_deconstruct(self):
self.assertEqual(args, [])
self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")})

def test_get_col(self):
class Square(Model):
side = IntegerField()
area = GeneratedField(expression=F("side") * F("side"), db_persist=True)

col = Square._meta.get_field("area").get_col("alias")
self.assertIsInstance(col.output_field, IntegerField)

class FloatSquare(Model):
side = IntegerField()
area = GeneratedField(
expression=F("side") * F("side"),
db_persist=True,
output_field=FloatField(),
)

col = FloatSquare._meta.get_field("area").get_col("alias")
self.assertIsInstance(col.output_field, FloatField)

def test_cached_col(self):
class Sum(Model):
a = IntegerField()
b = IntegerField()
total = GeneratedField(expression=F("a") + F("b"), db_persist=True)

field = Sum._meta.get_field("total")
cached_col = field.cached_col
self.assertIs(field.get_col(Sum._meta.db_table), cached_col)
self.assertIs(field.get_col(Sum._meta.db_table, field), cached_col)
self.assertIsNot(field.get_col("alias"), cached_col)
self.assertIsNot(field.get_col(Sum._meta.db_table, IntegerField()), cached_col)
self.assertIs(cached_col.target, field)
self.assertIsInstance(cached_col.output_field, IntegerField)


class GeneratedFieldTestMixin:
def _refresh_if_needed(self, m):
Expand Down

0 comments on commit 68d769e

Please sign in to comment.