Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Ascii.Equals when widening #87141

Merged
merged 4 commits into from
Jul 6, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
|| (typeof(TLeft) == typeof(byte) && typeof(TRight) == typeof(ushort))
|| (typeof(TLeft) == typeof(ushort) && typeof(TRight) == typeof(ushort)));

if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TRight>.Count)
if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TLeft>.Count)
{
for (nuint i = 0; i < length; ++i)
{
Expand All @@ -61,42 +61,30 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
}
}
}
else if (Avx.IsSupported && length >= (uint)Vector256<TRight>.Count)
else if (Avx.IsSupported && length >= (uint)Vector256<TLeft>.Count)
{
ref TLeft currentLeftSearchSpace = ref left;
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count256);
ref TRight currentRightSearchSpace = ref right;
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256<TRight>.Count);

Vector256<TRight> leftValues;
Vector256<TRight> rightValues;
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256<TLeft>.Count);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not understanding why this is valid. We're subtracting from the "right" search space the number of "left" elements in a vector?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works because TLeft and TRight are either the same type, or we are in the widen case where Vector<TLeft> is twice the size of Vector<TRight> and the widen code will advance twice Vector<TRight>.Count which is equal to 1 Vector<TLeft>.Count.

But it is written in a confusing way. The whole TLoader abstraction helps with code sharing but makes this part kind of yucky. Maybe if the compare method advanced the pointers it would be better?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. At a minimum a comment explaining would be helpful.


// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
do
{
leftValues = TLoader.Load256(ref currentLeftSearchSpace);
rightValues = Vector256.LoadUnsafe(ref currentRightSearchSpace);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
if (!TLoader.Compare256(ref currentLeftSearchSpace, ref currentRightSearchSpace))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TLoader does now the comparison, so the name should be adjusted to reflect that.

adamsitnik marked this conversation as resolved.
Show resolved Hide resolved
{
return false;
}

currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256<TRight>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count256);
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256<TLeft>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector256<TLeft>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));

// If any elements remain, process the last vector in the search space.
if (length % (uint)Vector256<TRight>.Count != 0)
if (length % (uint)Vector256<TLeft>.Count != 0)
{
leftValues = TLoader.Load256(ref oneVectorAwayFromLeftEnd);
rightValues = Vector256.LoadUnsafe(ref oneVectorAwayFromRightEnd);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
{
return false;
}
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector256<TLeft>.Count);
return TLoader.Compare256(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
}
}
else
Expand Down Expand Up @@ -363,6 +351,7 @@ private interface ILoader<TLeft, TRight>
static abstract nuint Count256 { get; }
static abstract Vector128<TRight> Load128(ref TLeft ptr);
static abstract Vector256<TRight> Load256(ref TLeft ptr);
static abstract bool Compare256(ref TLeft left, ref TRight right);
adamsitnik marked this conversation as resolved.
Show resolved Hide resolved
}

private readonly struct PlainLoader<T> : ILoader<T, T> where T : unmanaged, INumberBase<T>
Expand All @@ -371,6 +360,21 @@ private interface ILoader<TLeft, TRight>
public static nuint Count256 => (uint)Vector256<T>.Count;
public static Vector128<T> Load128(ref T ptr) => Vector128.LoadUnsafe(ref ptr);
public static Vector256<T> Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool Compare256(ref T left, ref T right)
adamsitnik marked this conversation as resolved.
Show resolved Hide resolved
{
Vector256<T> leftValues = Vector256.LoadUnsafe(ref left);
Vector256<T> rightValues = Vector256.LoadUnsafe(ref right);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
{
return false;
}

return true;
}
}

private readonly struct WideningLoader : ILoader<byte, ushort>
Expand Down Expand Up @@ -403,6 +407,31 @@ public static Vector256<ushort> Load256(ref byte ptr)
(Vector128<ushort> lower, Vector128<ushort> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref ptr));
return Vector256.Create(lower, upper);
Comment on lines 411 to 412
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(Vector128<ushort> lower, Vector128<ushort> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref ptr));
return Vector256.Create(lower, upper);
return Vector256.WidenLower(Vector128.LoadUnsafe(ref ptr).ToVector256Unsafe());

This results in better codegen when Avx2 is available.

Copy link
Contributor

@xtqqczze xtqqczze Jun 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This suggests a set of missing System.Runtime.Intrinsics.Vector256 APIs:

public static System.Runtime.Intrinsics.Vector256<ushort> Widen (System.Runtime.Intrinsics.Vector128<byte> source);
 public static System.Runtime.Intrinsics.Vector256<ushort> LoadWideningUnsafe (ref byte source);

Copy link
Contributor

@xtqqczze xtqqczze Jun 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This results in better codegen when Avx2 is available.

Codegen on arm64 is pretty bad though, probably should wrap with Avx2.IsSupported.

}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool Compare256(ref byte left, ref ushort right)
EgorBo marked this conversation as resolved.
Show resolved Hide resolved
{
Debug.Assert(Vector256<byte>.Count == Vector256<ushort>.Count * 2);

Vector256<byte> leftNotWidened = Vector256.LoadUnsafe(ref left);
if (!AllCharsInVectorAreAscii(leftNotWidened))
{
return false;
}

(Vector256<ushort> lower, Vector256<ushort> upper) = Vector256.Widen(leftNotWidened);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is me adding value:

Suggested change
(Vector256<ushort> lower, Vector256<ushort> upper) = Vector256.Widen(leftNotWidened);
var (lower, upper) = Vector256.Widen(leftNotWidened);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is me adding value:

Please stop adding value. :-P

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The BCL has a rule that var can only be used when the type is apparent.

It's a battle that I lost before ever joining the team 😅

Vector256<ushort> rightValues0 = Vector256.LoadUnsafe(ref right);
Vector256<ushort> rightValues1 = Vector256.LoadUnsafe(ref Unsafe.Add(ref right, (uint)Vector256<ushort>.Count));
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved

if (!Vector256<ushort>.AllBitsSet.Equals(
Vector256.BitwiseAnd(Vector256.Equals(lower, rightValues0), Vector256.Equals(upper, rightValues1))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should prefer the operators where possible:

Suggested change
if (!Vector256<ushort>.AllBitsSet.Equals(
Vector256.BitwiseAnd(Vector256.Equals(lower, rightValues0), Vector256.Equals(upper, rightValues1))))
if (Vector256<ushort>.AllBitsSet != (Vector256.Equals(lower, rightValues0) & Vector256.Equals(upper, rightValues1)))

bool Equals() in particular is the same as == for integral types, but isn't directly recognized as intrinsic and is an instance methjod. So the JIT has to inline and elide the reference taken for this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice to just make this return Vector256<ushort>.AllBitsSet == (Vector256.Equals(lower, rightValues0) & Vector256.Equals(upper, rightValues1)) instead as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look at the "Side-note on weird codegen observed" section at the bottom of the PR description.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be == if you noticed a suboptimal codegen - we'd better look and fix it in JIT instead of complicating C# code just to squeeze everything here and now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be as simple as:

return ((lower ^ rightValues0) | (upper ^ rightValues1)) == Vector256.Zero;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, sometimes it can be fixed with return cond ? true : false;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about:

Vector256<ushort> equals = Vector256.Equals(lower, rightValues0) & Vector256.Equals(upper, rightValues1);
if (equals.AsByte().ExtractMostSignificantBits() != 0xffffffffu)
    return false;
return true;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about:

Not needed, ((lower ^ rightValues0) | (upper ^ rightValues1)) == Vector256.Zero; is just a canonical way to do that. Also, ExtractMostSignificantBits is very expensive on arm

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@EgorBo Agreed, VPMOVMSKB has no equivalent on ARM64 (see #87141 (comment)), but this method is guarded by [CompExactlyDependsOn(typeof(Avx))].

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@EgorBo Agreed, VPMOVMSKB has no equivalent on ARM64 (see #87141 (comment)), but this method is guarded by [CompExactlyDependsOn(typeof(Avx))].

still, == Vector.Zero is expected to be lowered to MoveMask with SSE2 or to vptest with SSE41/AVX so no reason to do it by hands

{
return false;
}

return true;
}
}
}
}