Skip to content

Commit

Permalink
Fix DateOnly.Add{Days,Months,Years}
Browse files Browse the repository at this point in the history
Fixes #2888
  • Loading branch information
roji committed Nov 19, 2023
1 parent fa5ba2e commit d7cdbac
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ public class NpgsqlDateTimeMethodTranslator : IMethodCallTranslator
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddSeconds), new[] { typeof(double) })!, "secs" },
//{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMilliseconds), new[] { typeof(double) })!, "milliseconds" }

{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddYears), new[] { typeof(int) })!, "years" },
{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddMonths), new[] { typeof(int) })!, "months" },
{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddDays), new[] { typeof(int) })!, "days" },
// DateOnly.AddDays, AddMonths and AddYears have a specialized translation, see below
{ typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.AddHours), new[] { typeof(int) })!, "hours" },
{ typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.AddMinutes), new[] { typeof(int) })!, "mins" },
};
Expand Down Expand Up @@ -60,6 +58,15 @@ private static readonly MethodInfo DateOnly_Distance
= typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod(
nameof(NpgsqlDbFunctionsExtensions.Distance), new[] { typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly) })!;

private static readonly MethodInfo DateOnly_AddDays
= typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddDays), new[] { typeof(int) })!;

private static readonly MethodInfo DateOnly_AddMonths
= typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddMonths), new[] { typeof(int) })!;

private static readonly MethodInfo DateOnly_AddYears
= typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddYears), new[] { typeof(int) })!;

private static readonly MethodInfo TimeOnly_FromDateTime
= typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.FromDateTime), new[] { typeof(DateTime) })!;

Expand Down Expand Up @@ -118,60 +125,21 @@ public NpgsqlDateTimeMethodTranslator(
MethodInfo method,
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
=> TranslateDatePart(instance, method, arguments)
?? TranslateDateTime(instance, method, arguments)
?? TranslateDateOnly(instance, method, arguments)
?? TranslateTimeOnly(instance, method, arguments)
?? TranslateTimeZoneInfo(method, arguments);
=> TranslateDateTime(instance, method, arguments)
?? TranslateDateOnly(instance, method, arguments)
?? TranslateTimeOnly(instance, method, arguments)
?? TranslateTimeZoneInfo(method, arguments)
?? TranslateDatePart(instance, method, arguments);

private SqlExpression? TranslateDatePart(
SqlExpression? instance,
MethodInfo method,
IReadOnlyList<SqlExpression> arguments)
{
if (instance is null || !MethodInfoDatePartMapping.TryGetValue(method, out var datePart))
{
return null;
}

if (arguments[0] is not { } interval)
{
return null;
}

// Note: ideally we'd simply generate a PostgreSQL interval expression, but the .NET mapping of that is TimeSpan,
// which does not work for months, years, etc. So we generate special fragments instead.
if (interval is SqlConstantExpression constantExpression)
{
// We generate constant intervals as INTERVAL '1 days'
if (constantExpression.Type == typeof(double)
&& ((double)constantExpression.Value! >= int.MaxValue || (double)constantExpression.Value <= int.MinValue))
{
return null;
}

interval = _sqlExpressionFactory.Fragment(FormattableString.Invariant($"INTERVAL '{constantExpression.Value} {datePart}'"));
}
else
{
// For non-constants, we can't parameterize INTERVAL '1 days'. Instead, we use CAST($1 || ' days' AS interval).
// Note that a make_interval() function also exists, but accepts only int (for all fields except for
// seconds), so we don't use it.
// Note: we instantiate SqlBinaryExpression manually rather than via sqlExpressionFactory because
// of the non-standard Add expression (concatenate int with text)
interval = _sqlExpressionFactory.Convert(
new SqlBinaryExpression(
ExpressionType.Add,
_sqlExpressionFactory.Convert(interval, typeof(string), _textMapping),
_sqlExpressionFactory.Constant(' ' + datePart, _textMapping),
typeof(string),
_textMapping),
typeof(TimeSpan),
_intervalMapping);
}

return _sqlExpressionFactory.Add(instance, interval, instance.TypeMapping);
}
=> instance is not null
&& MethodInfoDatePartMapping.TryGetValue(method, out var datePart)
&& CreateIntervalExpression(arguments[0], datePart) is SqlExpression interval
? _sqlExpressionFactory.Add(instance, interval, instance.TypeMapping)
: null;

private SqlExpression? TranslateDateTime(
SqlExpression? instance,
Expand Down Expand Up @@ -270,6 +238,28 @@ public NpgsqlDateTimeMethodTranslator(
typeof(DateTime),
_timestampMapping);
}

// In PG, date + int = date (int interpreted as days)
if (method == DateOnly_AddDays)
{
return _sqlExpressionFactory.Add(instance, arguments[0]);
}

// For months and years, date + interval yields a timestamp (since interval could have a time component), so we need to cast
// the results back to date
if (method == DateOnly_AddMonths
&& CreateIntervalExpression(arguments[0], "months") is SqlExpression interval1)
{
return _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Add(instance, interval1, instance.TypeMapping), typeof(DateOnly));
}

if (method == DateOnly_AddYears
&& CreateIntervalExpression(arguments[0], "years") is SqlExpression interval2)
{
return _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Add(instance, interval2, instance.TypeMapping), typeof(DateOnly));
}
}

return null;
Expand Down Expand Up @@ -360,4 +350,36 @@ public NpgsqlDateTimeMethodTranslator(

return null;
}

private SqlExpression? CreateIntervalExpression(SqlExpression intervalNum, string datePart)
{
// Note: ideally we'd simply generate a PostgreSQL interval expression, but the .NET mapping of that is TimeSpan,
// which does not work for months, years, etc. So we generate special fragments instead.
if (intervalNum is SqlConstantExpression constantExpression)
{
// We generate constant intervals as INTERVAL '1 days'
if (constantExpression.Type == typeof(double)
&& ((double)constantExpression.Value! >= int.MaxValue || (double)constantExpression.Value <= int.MinValue))
{
return null;
}

return _sqlExpressionFactory.Fragment(FormattableString.Invariant($"INTERVAL '{constantExpression.Value} {datePart}'"));
}

// For non-constants, we can't parameterize INTERVAL '1 days'. Instead, we use CAST($1 || ' days' AS interval).
// Note that a make_interval() function also exists, but accepts only int (for all fields except for
// seconds), so we don't use it.
// Note: we instantiate SqlBinaryExpression manually rather than via sqlExpressionFactory because
// of the non-standard Add expression (concatenate int with text)
return _sqlExpressionFactory.Convert(
new SqlBinaryExpression(
ExpressionType.Add,
_sqlExpressionFactory.Convert(intervalNum, typeof(string), _textMapping),
_sqlExpressionFactory.Constant(' ' + datePart, _textMapping),
typeof(string),
_textMapping),
typeof(TimeSpan),
_intervalMapping);
}
}
1 change: 1 addition & 0 deletions src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ private SqlBinaryExpression ApplyTypeMappingOnSqlBinary(SqlBinaryExpression bina
case ExpressionType.Add or ExpressionType.Subtract
when right.Type == typeof(TimeSpan)
&& (left.Type == typeof(DateTime) || left.Type == typeof(DateTimeOffset) || left.Type == typeof(TimeOnly))
|| right.Type == typeof(int) && left.Type == typeof(DateOnly)
|| right.Type.FullName == "NodaTime.Period"
&& left.Type.FullName is "NodaTime.LocalDateTime" or "NodaTime.LocalDate" or "NodaTime.LocalTime"
|| right.Type.FullName == "NodaTime.Duration"
Expand Down
99 changes: 73 additions & 26 deletions test/EFCore.PG.FunctionalTests/Query/GearsOfWarQueryNpgsqlTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,7 @@ WHERE make_date(date_part('year', m."Date")::int, date_part('month', m."Date")::
[ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2039")]
public override async Task Where_DateOnly_Year(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>().Where(m => m.Date.Year == 1990).AsTracking());
await base.Where_DateOnly_Year(async);

AssertSql(
"""
Expand All @@ -489,9 +487,7 @@ WHERE date_part('year', m."Date")::int = 1990

public override async Task Where_DateOnly_Month(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>().Where(m => m.Date.Month == 11).AsTracking());
await base.Where_DateOnly_Month(async);

AssertSql(
"""
Expand All @@ -503,9 +499,7 @@ WHERE date_part('month', m."Date")::int = 11

public override async Task Where_DateOnly_Day(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>().Where(m => m.Date.Day == 10).AsTracking());
await base.Where_DateOnly_Day(async);

AssertSql(
"""
Expand All @@ -517,9 +511,7 @@ WHERE date_part('day', m."Date")::int = 10

public override async Task Where_DateOnly_DayOfYear(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>().Where(m => m.Date.DayOfYear == 314).AsTracking());
await base.Where_DateOnly_DayOfYear(async);

AssertSql(
"""
Expand All @@ -531,9 +523,7 @@ WHERE date_part('doy', m."Date")::int = 314

public override async Task Where_DateOnly_DayOfWeek(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>().Where(m => m.Date.DayOfWeek == DayOfWeek.Saturday).AsTracking());
await base.Where_DateOnly_DayOfWeek(async);

AssertSql(
"""
Expand All @@ -545,43 +535,100 @@ WHERE floor(date_part('dow', m."Date"))::int = 6

public override async Task Where_DateOnly_AddYears(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>().Where(m => m.Date.AddYears(3) == new DateOnly(1993, 11, 10)).AsTracking());
await base.Where_DateOnly_AddYears(async);

AssertSql(
"""
SELECT m."Id", m."CodeName", m."Date", m."Duration", m."Rating", m."Time", m."Timeline"
FROM "Missions" AS m
WHERE m."Date" + INTERVAL '3 years' = DATE '1993-11-10'
WHERE CAST(m."Date" + INTERVAL '3 years' AS date) = DATE '1993-11-10'
""");
}

public override async Task Where_DateOnly_AddMonths(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>().Where(m => m.Date.AddMonths(3) == new DateOnly(1991, 2, 10)).AsTracking());
await base.Where_DateOnly_AddMonths(async);

AssertSql(
"""
SELECT m."Id", m."CodeName", m."Date", m."Duration", m."Rating", m."Time", m."Timeline"
FROM "Missions" AS m
WHERE m."Date" + INTERVAL '3 months' = DATE '1991-02-10'
WHERE CAST(m."Date" + INTERVAL '3 months' AS date) = DATE '1991-02-10'
""");
}

public override async Task Where_DateOnly_AddDays(bool async)
{
await base.Where_DateOnly_AddDays(async);

AssertSql(
"""
SELECT m."Id", m."CodeName", m."Date", m."Duration", m."Rating", m."Time", m."Timeline"
FROM "Missions" AS m
WHERE m."Date" + 3 = DATE '1990-11-13'
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Select_DateOnly_AddDays(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>().Where(m => m.Date.AddDays(3) == new DateOnly(1990, 11, 13)).AsTracking());
ss => ss.Set<Mission>()
// We filter out DateOnly.MinValue which maps to -infinity
.Where(m => m.Date != DateOnly.MinValue)
.Select(m => m.Date.AddDays(3)));

AssertSql(
"""
SELECT m."Id", m."CodeName", m."Date", m."Duration", m."Rating", m."Time", m."Timeline"
@__MinValue_0='01/01/0001' (DbType = Date)

SELECT m."Date" + 3
FROM "Missions" AS m
WHERE m."Date" <> @__MinValue_0
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Select_DateOnly_AddMonths(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>()
// We filter out DateOnly.MinValue which maps to -infinity
.Where(m => m.Date != DateOnly.MinValue)
.Select(m => m.Date.AddMonths(3)));

AssertSql(
"""
@__MinValue_0='01/01/0001' (DbType = Date)

SELECT CAST(m."Date" + INTERVAL '3 months' AS date)
FROM "Missions" AS m
WHERE m."Date" <> @__MinValue_0
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Select_DateOnly_AddYears(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Mission>()
// We filter out DateOnly.MinValue which maps to -infinity
.Where(m => m.Date != DateOnly.MinValue)
.Select(m => m.Date.AddYears(3)));

AssertSql(
"""
@__MinValue_0='01/01/0001' (DbType = Date)

SELECT CAST(m."Date" + INTERVAL '3 years' AS date)
FROM "Missions" AS m
WHERE m."Date" + INTERVAL '3 days' = DATE '1990-11-13'
WHERE m."Date" <> @__MinValue_0
""");
}

Expand Down

0 comments on commit d7cdbac

Please sign in to comment.