Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Ascii.Equals when widening #87141

Merged
merged 4 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
|| (typeof(TLeft) == typeof(byte) && typeof(TRight) == typeof(ushort))
|| (typeof(TLeft) == typeof(ushort) && typeof(TRight) == typeof(ushort)));

if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TRight>.Count)
if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TLeft>.Count)
{
for (nuint i = 0; i < length; ++i)
{
Expand All @@ -61,42 +61,34 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
}
}
}
else if (Avx.IsSupported && length >= (uint)Vector256<TRight>.Count)
else if (Avx.IsSupported && length >= (uint)Vector256<TLeft>.Count)
{
ref TLeft currentLeftSearchSpace = ref left;
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count256);
ref TRight currentRightSearchSpace = ref right;
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256<TRight>.Count);

Vector256<TRight> leftValues;
Vector256<TRight> rightValues;
// Add Vector256<TLeft>.Count because TLeft == TRight
// Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count
Debug.Assert(Vector256<TLeft>.Count == Vector256<TRight>.Count
|| (typeof(TLoader) == typeof(WideningLoader) && Vector256<TLeft>.Count == Vector256<TRight>.Count * 2));
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256<TLeft>.Count);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not understanding why this is valid. We're subtracting from the "right" search space the number of "left" elements in a vector?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works because TLeft and TRight are either the same type, or we are in the widen case where Vector<TLeft> is twice the size of Vector<TRight> and the widen code will advance twice Vector<TRight>.Count which is equal to 1 Vector<TLeft>.Count.

But it is written in a confusing way. The whole TLoader abstraction helps with code sharing but makes this part kind of yucky. Maybe if the compare method advanced the pointers it would be better?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. At a minimum a comment explaining would be helpful.


// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
do
{
leftValues = TLoader.Load256(ref currentLeftSearchSpace);
rightValues = Vector256.LoadUnsafe(ref currentRightSearchSpace);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
if (!TLoader.EqualAndAscii(ref currentLeftSearchSpace, ref currentRightSearchSpace))
{
return false;
}

currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256<TRight>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count256);
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256<TLeft>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector256<TLeft>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));

// If any elements remain, process the last vector in the search space.
if (length % (uint)Vector256<TRight>.Count != 0)
if (length % (uint)Vector256<TLeft>.Count != 0)
{
leftValues = TLoader.Load256(ref oneVectorAwayFromLeftEnd);
rightValues = Vector256.LoadUnsafe(ref oneVectorAwayFromRightEnd);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
{
return false;
}
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector256<TLeft>.Count);
return TLoader.EqualAndAscii(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
}
}
else
Expand Down Expand Up @@ -363,6 +355,7 @@ private interface ILoader<TLeft, TRight>
static abstract nuint Count256 { get; }
static abstract Vector128<TRight> Load128(ref TLeft ptr);
static abstract Vector256<TRight> Load256(ref TLeft ptr);
static abstract bool EqualAndAscii(ref TLeft left, ref TRight right);
}

private readonly struct PlainLoader<T> : ILoader<T, T> where T : unmanaged, INumberBase<T>
Expand All @@ -371,6 +364,21 @@ private interface ILoader<TLeft, TRight>
public static nuint Count256 => (uint)Vector256<T>.Count;
public static Vector128<T> Load128(ref T ptr) => Vector128.LoadUnsafe(ref ptr);
public static Vector256<T> Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool EqualAndAscii(ref T left, ref T right)
{
Vector256<T> leftValues = Vector256.LoadUnsafe(ref left);
Vector256<T> rightValues = Vector256.LoadUnsafe(ref right);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
{
return false;
}

return true;
}
}

private readonly struct WideningLoader : ILoader<byte, ushort>
Expand Down Expand Up @@ -403,6 +411,32 @@ public static Vector256<ushort> Load256(ref byte ptr)
(Vector128<ushort> lower, Vector128<ushort> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref ptr));
return Vector256.Create(lower, upper);
Comment on lines 411 to 412
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(Vector128<ushort> lower, Vector128<ushort> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref ptr));
return Vector256.Create(lower, upper);
return Vector256.WidenLower(Vector128.LoadUnsafe(ref ptr).ToVector256Unsafe());

This results in better codegen when Avx2 is available.

Copy link
Contributor

@xtqqczze xtqqczze Jun 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This suggests a set of missing System.Runtime.Intrinsics.Vector256 APIs:

public static System.Runtime.Intrinsics.Vector256<ushort> Widen (System.Runtime.Intrinsics.Vector128<byte> source);
 public static System.Runtime.Intrinsics.Vector256<ushort> LoadWideningUnsafe (ref byte source);

Copy link
Contributor

@xtqqczze xtqqczze Jun 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This results in better codegen when Avx2 is available.

Codegen on arm64 is pretty bad though, probably should wrap with Avx2.IsSupported.

}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
{
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
Debug.Assert(Vector256<byte>.Count == Vector256<ushort>.Count * 2);

Vector256<byte> leftNotWidened = Vector256.LoadUnsafe(ref utf8);
if (!AllCharsInVectorAreAscii(leftNotWidened))
{
return false;
}

(Vector256<ushort> leftLower, Vector256<ushort> leftUpper) = Vector256.Widen(leftNotWidened);
Vector256<ushort> right = Vector256.LoadUnsafe(ref utf16);
Vector256<ushort> rightNext = Vector256.LoadUnsafe(ref utf16, (uint)Vector256<ushort>.Count);

// A branchless version of "leftLower != right || leftUpper != rightNext"
if (((leftLower ^ right) | (leftUpper ^ rightNext)) != Vector256<ushort>.Zero)
{
return false;
}

return true;
}
}
}
}
62 changes: 57 additions & 5 deletions src/libraries/System.Text.Encoding/tests/Ascii/EqualsTests.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Intrinsics;
using Xunit;

namespace System.Text.Tests
{
public abstract class AsciiEqualityTests
public abstract class AsciiEqualityTests<TLeft, TRight>
where TLeft : unmanaged
where TRight : unmanaged
{
protected abstract bool Equals(string left, string right);
protected abstract bool EqualsIgnoreCase(string left, string right);
protected abstract bool Equals(byte[] left, byte[] right);
protected abstract bool EqualsIgnoreCase(byte[] left, byte[] right);
protected abstract bool Equals(ReadOnlySpan<TLeft> left, ReadOnlySpan<TRight> right);
protected abstract bool EqualsIgnoreCase(ReadOnlySpan<TLeft> left, ReadOnlySpan<TRight> right);

public static IEnumerable<object[]> ValidAsciiInputs
{
Expand Down Expand Up @@ -140,9 +145,32 @@ public void Equals_EqualValues_ButNonAscii_ReturnsFalse(byte[] input)
[MemberData(nameof(ContainingNonAsciiCharactersBuffers))]
public void EqualsIgnoreCase_EqualValues_ButNonAscii_ReturnsFalse(byte[] input)
=> Assert.False(EqualsIgnoreCase(input, input));

[Theory]
[InlineData(PoisonPagePlacement.After, PoisonPagePlacement.After)]
[InlineData(PoisonPagePlacement.After, PoisonPagePlacement.Before)]
[InlineData(PoisonPagePlacement.Before, PoisonPagePlacement.After)]
[InlineData(PoisonPagePlacement.Before, PoisonPagePlacement.Before)]
public void Boundaries_Are_Respected(PoisonPagePlacement leftPoison, PoisonPagePlacement rightPoison)
{
for (int size = 1; size < 129; size++)
{
using BoundedMemory<TLeft> left = BoundedMemory.Allocate<TLeft>(size, leftPoison);
using BoundedMemory<TRight> right = BoundedMemory.Allocate<TRight>(size, rightPoison);

left.Span.Fill(default);
right.Span.Fill(default);

left.MakeReadonly();
right.MakeReadonly();

Assert.True(Equals(left.Span, right.Span));
Assert.True(EqualsIgnoreCase(left.Span, right.Span));
}
}
}

public class AsciiEqualityTests_Byte_Byte : AsciiEqualityTests
public class AsciiEqualityTests_Byte_Byte : AsciiEqualityTests<byte, byte>
{
protected override bool Equals(string left, string right)
=> Ascii.Equals(Encoding.ASCII.GetBytes(left), Encoding.ASCII.GetBytes(right));
Expand All @@ -155,9 +183,15 @@ protected override bool Equals(byte[] left, byte[] right)

protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
=> Ascii.EqualsIgnoreCase(left, right);

protected override bool Equals(ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
=> Ascii.Equals(left, right);

protected override bool EqualsIgnoreCase(ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
=> Ascii.EqualsIgnoreCase(left, right);
}

public class AsciiEqualityTests_Byte_Char : AsciiEqualityTests
public class AsciiEqualityTests_Byte_Char : AsciiEqualityTests<byte, char>
{
protected override bool Equals(string left, string right)
=> Ascii.Equals(Encoding.ASCII.GetBytes(left), right);
Expand All @@ -170,9 +204,15 @@ protected override bool Equals(byte[] left, byte[] right)

protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
=> Ascii.EqualsIgnoreCase(left, right.Select(b => (char)b).ToArray());

protected override bool Equals(ReadOnlySpan<byte> left, ReadOnlySpan<char> right)
=> Ascii.Equals(left, right);

protected override bool EqualsIgnoreCase(ReadOnlySpan<byte> left, ReadOnlySpan<char> right)
=> Ascii.EqualsIgnoreCase(left, right);
}

public class AsciiEqualityTests_Char_Byte : AsciiEqualityTests
public class AsciiEqualityTests_Char_Byte : AsciiEqualityTests<char, byte>
{
protected override bool Equals(string left, string right)
=> Ascii.Equals(left, Encoding.ASCII.GetBytes(right));
Expand All @@ -185,9 +225,15 @@ protected override bool Equals(byte[] left, byte[] right)

protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
=> Ascii.EqualsIgnoreCase(left.Select(b => (char)b).ToArray(), right);

protected override bool Equals(ReadOnlySpan<char> left, ReadOnlySpan<byte> right)
=> Ascii.Equals(left, right);

protected override bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<byte> right)
=> Ascii.EqualsIgnoreCase(left, right);
}

public class AsciiEqualityTests_Char_Char : AsciiEqualityTests
public class AsciiEqualityTests_Char_Char : AsciiEqualityTests<char, char>
{
protected override bool Equals(string left, string right)
=> Ascii.Equals(left, right);
Expand All @@ -200,5 +246,11 @@ protected override bool Equals(byte[] left, byte[] right)

protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
=> Ascii.EqualsIgnoreCase(left.Select(b => (char)b).ToArray(), right.Select(b => (char)b).ToArray());

protected override bool Equals(ReadOnlySpan<char> left, ReadOnlySpan<char> right)
=> Ascii.Equals(left, right);

protected override bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<char> right)
=> Ascii.EqualsIgnoreCase(left, right);
}
}
Loading