Skip to content

Commit

Permalink
Optimize Ascii.Equals when widening (#87141)
Browse files Browse the repository at this point in the history
* Optimize Ascii.Equals when widening

* Add BoundedMemory tests to ensure that boundaries are respected

---------

Co-authored-by: Adam Sitnik <[email protected]>
  • Loading branch information
BrennanConroy and adamsitnik authored Jul 6, 2023
1 parent b78345e commit bd63402
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
|| (typeof(TLeft) == typeof(byte) && typeof(TRight) == typeof(ushort))
|| (typeof(TLeft) == typeof(ushort) && typeof(TRight) == typeof(ushort)));

if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TRight>.Count)
if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TLeft>.Count)
{
for (nuint i = 0; i < length; ++i)
{
Expand All @@ -61,42 +61,34 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
}
}
}
else if (Avx.IsSupported && length >= (uint)Vector256<TRight>.Count)
else if (Avx.IsSupported && length >= (uint)Vector256<TLeft>.Count)
{
ref TLeft currentLeftSearchSpace = ref left;
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count256);
ref TRight currentRightSearchSpace = ref right;
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256<TRight>.Count);

Vector256<TRight> leftValues;
Vector256<TRight> rightValues;
// Advance both sides by Vector256<TLeft>.Count elements per iteration: either TLeft == TRight,
// or we are in the widening case, where each iteration consumes 2 * Vector256<TRight>.Count elements, which equals Vector256<TLeft>.Count.
Debug.Assert(Vector256<TLeft>.Count == Vector256<TRight>.Count
|| (typeof(TLoader) == typeof(WideningLoader) && Vector256<TLeft>.Count == Vector256<TRight>.Count * 2));
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256<TLeft>.Count);

// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
do
{
leftValues = TLoader.Load256(ref currentLeftSearchSpace);
rightValues = Vector256.LoadUnsafe(ref currentRightSearchSpace);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
if (!TLoader.EqualAndAscii(ref currentLeftSearchSpace, ref currentRightSearchSpace))
{
return false;
}

currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256<TRight>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count256);
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256<TLeft>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector256<TLeft>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));

// If any elements remain, process the last vector in the search space.
if (length % (uint)Vector256<TRight>.Count != 0)
if (length % (uint)Vector256<TLeft>.Count != 0)
{
leftValues = TLoader.Load256(ref oneVectorAwayFromLeftEnd);
rightValues = Vector256.LoadUnsafe(ref oneVectorAwayFromRightEnd);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
{
return false;
}
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector256<TLeft>.Count);
return TLoader.EqualAndAscii(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
}
}
else
Expand Down Expand Up @@ -363,6 +355,7 @@ private interface ILoader<TLeft, TRight>
static abstract nuint Count256 { get; }
static abstract Vector128<TRight> Load128(ref TLeft ptr);
static abstract Vector256<TRight> Load256(ref TLeft ptr);
/// <summary>
/// Compares one 256-bit vector loaded from <paramref name="left"/> against the corresponding
/// elements starting at <paramref name="right"/>.
/// </summary>
/// <returns>
/// <see langword="false"/> when the compared elements differ or the loaded left vector fails
/// the implementation's <c>AllCharsInVectorAreAscii</c> check; otherwise <see langword="true"/>.
/// </returns>
static abstract bool EqualAndAscii(ref TLeft left, ref TRight right);
}

private readonly struct PlainLoader<T> : ILoader<T, T> where T : unmanaged, INumberBase<T>
Expand All @@ -371,6 +364,21 @@ private interface ILoader<TLeft, TRight>
public static nuint Count256 => (uint)Vector256<T>.Count;
public static Vector128<T> Load128(ref T ptr) => Vector128.LoadUnsafe(ref ptr);
public static Vector256<T> Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr);

/// <summary>
/// Loads one <see cref="Vector256{T}"/> from each of <paramref name="left"/> and <paramref name="right"/>
/// and reports whether the two vectors are equal and the left one passes <c>AllCharsInVectorAreAscii</c>.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool EqualAndAscii(ref T left, ref T right)
{
    Vector256<T> lhs = Vector256.LoadUnsafe(ref left);
    Vector256<T> rhs = Vector256.LoadUnsafe(ref right);

    // Only the left vector is ASCII-checked: it is only consulted once both vectors are known to be equal.
    return lhs == rhs && AllCharsInVectorAreAscii(lhs);
}
}

private readonly struct WideningLoader : ILoader<byte, ushort>
Expand Down Expand Up @@ -403,6 +411,32 @@ public static Vector256<ushort> Load256(ref byte ptr)
(Vector128<ushort> lower, Vector128<ushort> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref ptr));
return Vector256.Create(lower, upper);
}

/// <summary>
/// Loads one <see cref="Vector256{T}"/> of bytes from <paramref name="utf8"/>, widens it to two vectors of
/// <see cref="ushort"/>, and compares them with two consecutive vectors loaded from <paramref name="utf16"/>.
/// </summary>
/// <returns>
/// <see langword="false"/> when the byte vector fails <c>AllCharsInVectorAreAscii</c> or any widened
/// element differs from its utf16 counterpart; otherwise <see langword="true"/>.
/// </returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
{
    // One vector of bytes widens into two vectors of ushort, so each call covers two utf16 vectors.
    Debug.Assert(Vector256<byte>.Count == Vector256<ushort>.Count * 2);

    Vector256<byte> narrow = Vector256.LoadUnsafe(ref utf8);
    if (AllCharsInVectorAreAscii(narrow))
    {
        (Vector256<ushort> widenedLow, Vector256<ushort> widenedHigh) = Vector256.Widen(narrow);
        Vector256<ushort> utf16Low = Vector256.LoadUnsafe(ref utf16);
        Vector256<ushort> utf16High = Vector256.LoadUnsafe(ref utf16, (uint)Vector256<ushort>.Count);

        // XOR leaves a non-zero element wherever the two sides differ; OR-ing both halves
        // lets a single comparison against zero cover the whole range without a second branch.
        Vector256<ushort> difference = (widenedLow ^ utf16Low) | (widenedHigh ^ utf16High);
        return difference == Vector256<ushort>.Zero;
    }

    // The utf8 input failed the ASCII check; utf16 is never read in this case.
    return false;
}
}
}
}
62 changes: 57 additions & 5 deletions src/libraries/System.Text.Encoding/tests/Ascii/EqualsTests.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Intrinsics;
using Xunit;

namespace System.Text.Tests
{
public abstract class AsciiEqualityTests
public abstract class AsciiEqualityTests<TLeft, TRight>
where TLeft : unmanaged
where TRight : unmanaged
{
protected abstract bool Equals(string left, string right);
protected abstract bool EqualsIgnoreCase(string left, string right);
protected abstract bool Equals(byte[] left, byte[] right);
protected abstract bool EqualsIgnoreCase(byte[] left, byte[] right);
// Span-based entry points, overridden by each derived class for its TLeft/TRight combination.
// They let tests pass spans directly instead of going through string or byte[] conversions.
protected abstract bool Equals(ReadOnlySpan<TLeft> left, ReadOnlySpan<TRight> right);
protected abstract bool EqualsIgnoreCase(ReadOnlySpan<TLeft> left, ReadOnlySpan<TRight> right);

public static IEnumerable<object[]> ValidAsciiInputs
{
Expand Down Expand Up @@ -140,9 +145,32 @@ public void Equals_EqualValues_ButNonAscii_ReturnsFalse(byte[] input)
[MemberData(nameof(ContainingNonAsciiCharactersBuffers))]
public void EqualsIgnoreCase_EqualValues_ButNonAscii_ReturnsFalse(byte[] input)
=> Assert.False(EqualsIgnoreCase(input, input));

// Runs Equals and EqualsIgnoreCase over bounded buffers of every length from 1 to 128,
// for each combination of poison-page placement on the left and right buffers.
// NOTE(review): presumably a read outside a BoundedMemory buffer faults on the poison page - confirm against BoundedMemory.
[Theory]
[InlineData(PoisonPagePlacement.After, PoisonPagePlacement.After)]
[InlineData(PoisonPagePlacement.After, PoisonPagePlacement.Before)]
[InlineData(PoisonPagePlacement.Before, PoisonPagePlacement.After)]
[InlineData(PoisonPagePlacement.Before, PoisonPagePlacement.Before)]
public void Boundaries_Are_Respected(PoisonPagePlacement leftPoison, PoisonPagePlacement rightPoison)
{
    const int MaxLength = 128;

    for (int length = 1; length <= MaxLength; length++)
    {
        using BoundedMemory<TLeft> leftBuffer = BoundedMemory.Allocate<TLeft>(length, leftPoison);
        using BoundedMemory<TRight> rightBuffer = BoundedMemory.Allocate<TRight>(length, rightPoison);

        // Both buffers hold default (zero) elements, so they are expected to compare equal.
        leftBuffer.Span.Fill(default);
        leftBuffer.MakeReadonly();

        rightBuffer.Span.Fill(default);
        rightBuffer.MakeReadonly();

        bool equal = Equals(leftBuffer.Span, rightBuffer.Span);
        Assert.True(equal);

        bool equalIgnoringCase = EqualsIgnoreCase(leftBuffer.Span, rightBuffer.Span);
        Assert.True(equalIgnoringCase);
    }
}
}

public class AsciiEqualityTests_Byte_Byte : AsciiEqualityTests
public class AsciiEqualityTests_Byte_Byte : AsciiEqualityTests<byte, byte>
{
// Encodes both strings as ASCII bytes and compares them through the byte/byte overload.
protected override bool Equals(string left, string right)
{
    byte[] leftBytes = Encoding.ASCII.GetBytes(left);
    byte[] rightBytes = Encoding.ASCII.GetBytes(right);
    return Ascii.Equals(leftBytes, rightBytes);
}
Expand All @@ -155,9 +183,15 @@ protected override bool Equals(byte[] left, byte[] right)

// Passes both byte arrays straight to the byte/byte case-insensitive comparison.
protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
{
    return Ascii.EqualsIgnoreCase(left, right);
}

// Forwards the byte/byte spans to Ascii.Equals unchanged.
protected override bool Equals(ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
{
    return Ascii.Equals(left, right);
}

// Forwards the byte/byte spans to Ascii.EqualsIgnoreCase unchanged.
protected override bool EqualsIgnoreCase(ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
{
    return Ascii.EqualsIgnoreCase(left, right);
}
}

public class AsciiEqualityTests_Byte_Char : AsciiEqualityTests
public class AsciiEqualityTests_Byte_Char : AsciiEqualityTests<byte, char>
{
// Encodes only the left string as ASCII bytes; the right string is compared as chars.
protected override bool Equals(string left, string right)
{
    byte[] leftBytes = Encoding.ASCII.GetBytes(left);
    return Ascii.Equals(leftBytes, right);
}
Expand All @@ -170,9 +204,15 @@ protected override bool Equals(byte[] left, byte[] right)

// Converts the right byte array to chars element by element, then compares byte/char ignoring case.
protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
{
    char[] rightChars = right.Select(b => (char)b).ToArray();
    return Ascii.EqualsIgnoreCase(left, rightChars);
}

// Forwards the byte/char spans to Ascii.Equals unchanged.
protected override bool Equals(ReadOnlySpan<byte> left, ReadOnlySpan<char> right)
{
    return Ascii.Equals(left, right);
}

// Forwards the byte/char spans to Ascii.EqualsIgnoreCase unchanged.
protected override bool EqualsIgnoreCase(ReadOnlySpan<byte> left, ReadOnlySpan<char> right)
{
    return Ascii.EqualsIgnoreCase(left, right);
}
}

public class AsciiEqualityTests_Char_Byte : AsciiEqualityTests
public class AsciiEqualityTests_Char_Byte : AsciiEqualityTests<char, byte>
{
// Encodes only the right string as ASCII bytes; the left string is compared as chars.
protected override bool Equals(string left, string right)
{
    byte[] rightBytes = Encoding.ASCII.GetBytes(right);
    return Ascii.Equals(left, rightBytes);
}
Expand All @@ -185,9 +225,15 @@ protected override bool Equals(byte[] left, byte[] right)

// Converts the left byte array to chars element by element, then compares char/byte ignoring case.
protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
{
    char[] leftChars = left.Select(b => (char)b).ToArray();
    return Ascii.EqualsIgnoreCase(leftChars, right);
}

// Forwards the char/byte spans to Ascii.Equals unchanged.
protected override bool Equals(ReadOnlySpan<char> left, ReadOnlySpan<byte> right)
{
    return Ascii.Equals(left, right);
}

// Forwards the char/byte spans to Ascii.EqualsIgnoreCase unchanged.
protected override bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<byte> right)
{
    return Ascii.EqualsIgnoreCase(left, right);
}
}

public class AsciiEqualityTests_Char_Char : AsciiEqualityTests
public class AsciiEqualityTests_Char_Char : AsciiEqualityTests<char, char>
{
// Compares the two strings directly through the char/char overload.
protected override bool Equals(string left, string right)
{
    return Ascii.Equals(left, right);
}
Expand All @@ -200,5 +246,11 @@ protected override bool Equals(byte[] left, byte[] right)

// Converts both byte arrays to chars element by element, then compares char/char ignoring case.
protected override bool EqualsIgnoreCase(byte[] left, byte[] right)
{
    char[] leftChars = left.Select(b => (char)b).ToArray();
    char[] rightChars = right.Select(b => (char)b).ToArray();
    return Ascii.EqualsIgnoreCase(leftChars, rightChars);
}

// Forwards the char/char spans to Ascii.Equals unchanged.
protected override bool Equals(ReadOnlySpan<char> left, ReadOnlySpan<char> right)
{
    return Ascii.Equals(left, right);
}

// Forwards the char/char spans to Ascii.EqualsIgnoreCase unchanged.
protected override bool EqualsIgnoreCase(ReadOnlySpan<char> left, ReadOnlySpan<char> right)
{
    return Ascii.EqualsIgnoreCase(left, right);
}
}
}

0 comments on commit bd63402

Please sign in to comment.