Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for translating Array.IndexOf methods for byte arrays for SqlServer & SQLite #34457

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,40 +40,106 @@ public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFac
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& arguments.Count >= 1
&& arguments[0].Type == typeof(byte[]))
{
var source = arguments[0];
var sourceTypeMapping = source.TypeMapping;
var methodDefinition = method.GetGenericMethodDefinition();
if (methodDefinition.Equals(EnumerableMethods.Contains))
{
var source = arguments[0];
var sourceTypeMapping = source.TypeMapping;

var value = arguments[1] is SqlConstantExpression constantValue
? _sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), sourceTypeMapping);
var value = arguments[1] is SqlConstantExpression constantValue
? _sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), sourceTypeMapping);

return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"CHARINDEX",
[value, source],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int)),
_sqlExpressionFactory.Constant(0));
}
return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"CHARINDEX",
[value, source],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int)),
_sqlExpressionFactory.Constant(0));
}

if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.FirstWithoutPredicate)
&& arguments[0].Type == typeof(byte[]))
{
return _sqlExpressionFactory.Convert(
if (methodDefinition.Equals(EnumerableMethods.FirstWithoutPredicate))
{
return _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"SUBSTRING",
[arguments[0], _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)],
nullable: true,
argumentsPropagateNullability: [true, true, true],
typeof(byte[])),
method.ReturnType);
}

if (methodDefinition.Equals(ArrayMethods.IndexOf))
{
return TranslateIndexOf(method, arguments[0], arguments[1], null);
}

if (methodDefinition.Equals(ArrayMethods.IndexOfWithStartingPosition))
{
return TranslateIndexOf(method, arguments[0], arguments[1], arguments[2]);
}
}

return null;
}

private SqlExpression TranslateIndexOf(
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
MethodInfo method,
SqlExpression source,
SqlExpression valueToSearch,
SqlExpression? startIndex
)
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
{
var sourceTypeMapping = source.TypeMapping;
var sqlArguments = new List<SqlExpression>
{
valueToSearch is SqlConstantExpression { Value: byte constantValue }
? _sqlExpressionFactory.Constant(new byte[] { constantValue }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(valueToSearch, typeof(byte[]), sourceTypeMapping),
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
source
};

if (startIndex is not null)
{
sqlArguments.Add(
startIndex is SqlConstantExpression { Value : int index }
? _sqlExpressionFactory.Constant(index + 1, typeof(int))
: _sqlExpressionFactory.Add(startIndex, _sqlExpressionFactory.Constant(1))
);
}

var argumentsPropagateNullability = Enumerable.Repeat(true, sqlArguments.Count);

SqlExpression charIndexExpr;
var storeType = sourceTypeMapping?.StoreType;
if (storeType == "varbinary(max)")
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
{
charIndexExpr = _sqlExpressionFactory.Function(
"CHARINDEX",
sqlArguments,
nullable: true,
argumentsPropagateNullability: argumentsPropagateNullability,
typeof(long));

charIndexExpr = _sqlExpressionFactory.Convert(charIndexExpr, typeof(int));
}
else
{
charIndexExpr = _sqlExpressionFactory.Function(
"CHARINDEX",
sqlArguments,
nullable: true,
argumentsPropagateNullability: argumentsPropagateNullability,
method.ReturnType);
}


return _sqlExpressionFactory.Subtract(charIndexExpr, _sqlExpressionFactory.Constant(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,30 @@ public SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactor
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& arguments.Count >= 1
&& arguments[0].Type == typeof(byte[]))
{
var source = arguments[0];
var genericMethodDefinition = method.GetGenericMethodDefinition();
if (genericMethodDefinition.Equals(EnumerableMethods.Contains))
{
return _sqlExpressionFactory.GreaterThan(
GetInStrSqlFunctionExpression(arguments[0], arguments[1]),
_sqlExpressionFactory.Constant(0));

var value = arguments[1] is SqlConstantExpression constantValue
? (SqlExpression)_sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, source.TypeMapping)
: _sqlExpressionFactory.Function(
"char",
new[] { arguments[1] },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(string));
}

return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"instr",
new[] { source, value },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(int)),
_sqlExpressionFactory.Constant(0));
if (genericMethodDefinition.Equals(ArrayMethods.IndexOf))
{
return _sqlExpressionFactory.Subtract(
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
GetInStrSqlFunctionExpression(arguments[0], arguments[1]),
_sqlExpressionFactory.Constant(1));
}

if (genericMethodDefinition.Equals(ArrayMethods.IndexOfWithStartingPosition))
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
{
// NOTE: IndexOf Method with a starting position is not supported by SQLite
return null;
}
}

// See issue#16428
Expand Down Expand Up @@ -92,4 +94,23 @@ public SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactor

return null;
}

private SqlExpression GetInStrSqlFunctionExpression(SqlExpression source, SqlExpression valueToSearch)
{
var value = valueToSearch is SqlConstantExpression { Value: byte constantValue }
? _sqlExpressionFactory.Constant(new byte[] { constantValue }, source.TypeMapping)
: _sqlExpressionFactory.Function(
"char",
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
[valueToSearch],
nullable: false,
argumentsPropagateNullability: [false],
typeof(string));

return _sqlExpressionFactory.Function(
"instr",
[source, value],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int));
}
}
36 changes: 36 additions & 0 deletions src/Shared/ArrayMethods.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore;

internal static class ArrayMethods
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
{
public static MethodInfo IndexOf { get; }

public static MethodInfo IndexOfWithStartingPosition { get; }

static ArrayMethods()
{
var arrayGenericMethods = typeof(Array)
.GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Where(m => m.IsGenericMethod)
.GroupBy(m => m.Name)
.ToDictionary(m => m.Key, l => l.ToList());

IndexOf = GetMethod(nameof(Array.IndexOf), 1, (t) =>
{
return [t[0].MakeArrayType(), t[0]];
});

IndexOfWithStartingPosition = GetMethod(nameof(Array.IndexOf), 1, (t) =>
{
return [t[0].MakeArrayType(), t[0], typeof(int)];
});

MethodInfo GetMethod(string name, int genericParameterCount, Func<Type[], Type[]> parameterGenerator)
=> arrayGenericMethods[name].Single(
mi => mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount
&& mi.GetParameters().Select(e => e.ParameterType).SequenceEqual(
parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : [])));
}
}
90 changes: 90 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6259,6 +6259,96 @@ public virtual Task Byte_array_filter_by_length_parameter(bool async)
ss => ss.Set<Squad>().Where(w => w.Banner != null && w.Banner.Length == someByteArr.Length));
}

#region Byte Array IndexOf

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_max_possible_length_filter_by_index_of_literal(bool async)
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1) == 1),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cast here shouldn't be needed, no?

Suggested change
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1) == 1),
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, 1) == 1),

Copy link
Author

@nikhil197 nikhil197 Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No actually, it is needed. Without this it's picking non-generic version of the method IndexOf(Array arr, object value) (because 1 is an Int32 by default and I haven't specified the type argument on the IndexOf).

Do we want to support that too?

ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, (byte)1) == 1)
);
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_max_possible_length_filter_by_index_of_parameter(bool async)
{
byte b = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, b) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, b) == 0)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_length_n_filter_by_index_of_literal(bool async)
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, (byte)5) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, (byte)5) == 1)
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_lenght_n_filter_by_index_of_parameter(bool async)
{
byte b = 4;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, b) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, b) == 0)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_max_possible_length_filter_by_index_of_with_starting_position(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1, 1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, (byte)1, 1) == 1)
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_max_possible_length_filter_by_index_of_with_starting_position_parameter(bool async)
{
byte b = 0;
var startPos = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, b, startPos) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, b, startPos) == 0)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_length_n_filter_by_index_of_with_starting_position_literal(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, (byte)5, 1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, (byte)5, 1) == 1)
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_length_n_filter_by_index_of_with_starting_position_parameter(bool async)
{
byte b = 4;
var startPos = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, b, startPos) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, b, startPos) == 0)
);
}

#endregion

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task OrderBy_bool_coming_from_optional_navigation(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class Squad
{
public Squad()
{
Members = new List<Gear>();
Members = [];
}

// non-auto generated key
Expand Down
Loading