diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs index 5a9e9ef09cfb1..a2d60599fae39 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs @@ -48,7 +48,7 @@ private static bool Equals(ref TLeft left, ref TRight ri || (typeof(TLeft) == typeof(byte) && typeof(TRight) == typeof(ushort)) || (typeof(TLeft) == typeof(ushort) && typeof(TRight) == typeof(ushort))); - if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128.Count) + if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128.Count) { for (nuint i = 0; i < length; ++i) { @@ -61,42 +61,34 @@ private static bool Equals(ref TLeft left, ref TRight ri } } } - else if (Avx.IsSupported && length >= (uint)Vector256.Count) + else if (Avx.IsSupported && length >= (uint)Vector256.Count) { ref TLeft currentLeftSearchSpace = ref left; - ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count256); ref TRight currentRightSearchSpace = ref right; - ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256.Count); - - Vector256 leftValues; - Vector256 rightValues; + // Add Vector256.Count because TLeft == TRight + // Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count + Debug.Assert(Vector256.Count == Vector256.Count + || (typeof(TLoader) == typeof(WideningLoader) && Vector256.Count == Vector256.Count * 2)); + ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector256.Count); // Loop until either we've finished all elements or there's less than a vector's-worth remaining. do { - leftValues = TLoader.Load256(ref currentLeftSearchSpace); - rightValues = Vector256.LoadUnsafe(ref currentRightSearchSpace); - - if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues)) + if (!TLoader.EqualAndAscii(ref currentLeftSearchSpace, ref currentRightSearchSpace)) { return false; } - currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256.Count); - currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count256); + currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector256.Count); + currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector256.Count); } while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd)); // If any elements remain, process the last vector in the search space. - if (length % (uint)Vector256.Count != 0) + if (length % (uint)Vector256.Count != 0) { - leftValues = TLoader.Load256(ref oneVectorAwayFromLeftEnd); - rightValues = Vector256.LoadUnsafe(ref oneVectorAwayFromRightEnd); - - if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues)) - { - return false; - } + ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector256.Count); + return TLoader.EqualAndAscii(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd); } } else @@ -363,6 +355,7 @@ private interface ILoader static abstract nuint Count256 { get; } static abstract Vector128 Load128(ref TLeft ptr); static abstract Vector256 Load256(ref TLeft ptr); + static abstract bool EqualAndAscii(ref TLeft left, ref TRight right); } private readonly struct PlainLoader : ILoader where T : unmanaged, INumberBase @@ -371,6 +364,21 @@ private interface ILoader public static nuint Count256 => (uint)Vector256.Count; public static Vector128 Load128(ref T ptr) => Vector128.LoadUnsafe(ref ptr); public static Vector256 Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx))] + public static bool EqualAndAscii(ref T left, ref T right) + { + Vector256 leftValues = Vector256.LoadUnsafe(ref left); + Vector256 rightValues = Vector256.LoadUnsafe(ref right); + + if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues)) + { + return false; + } + + return true; + } } private readonly struct WideningLoader : ILoader @@ -403,6 +411,32 @@ public static Vector256 Load256(ref byte ptr) (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref ptr)); return Vector256.Create(lower, upper); } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CompExactlyDependsOn(typeof(Avx))] + public static bool EqualAndAscii(ref byte utf8, ref ushort utf16) + { + // We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search + Debug.Assert(Vector256.Count == Vector256.Count * 2); + + Vector256 leftNotWidened = Vector256.LoadUnsafe(ref utf8); + if (!AllCharsInVectorAreAscii(leftNotWidened)) + { + return false; + } + + (Vector256 leftLower, Vector256 leftUpper) = Vector256.Widen(leftNotWidened); + Vector256 right = Vector256.LoadUnsafe(ref utf16); + Vector256 rightNext = Vector256.LoadUnsafe(ref utf16, (uint)Vector256.Count); + + // A branchless version of "leftLower != right || leftUpper != rightNext" + if (((leftLower ^ right) | (leftUpper ^ rightNext)) != Vector256.Zero) + { + return false; + } + + return true; + } } } } diff --git a/src/libraries/System.Text.Encoding/tests/Ascii/EqualsTests.cs b/src/libraries/System.Text.Encoding/tests/Ascii/EqualsTests.cs index c2186defc5e12..2ca033b46e555 100644 --- a/src/libraries/System.Text.Encoding/tests/Ascii/EqualsTests.cs +++ b/src/libraries/System.Text.Encoding/tests/Ascii/EqualsTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Runtime.Intrinsics; @@ -8,12 +9,16 @@ namespace System.Text.Tests { - public abstract class AsciiEqualityTests + public abstract class AsciiEqualityTests + where TLeft : unmanaged + where TRight : unmanaged { protected abstract bool Equals(string left, string right); protected abstract bool EqualsIgnoreCase(string left, string right); protected abstract bool Equals(byte[] left, byte[] right); protected abstract bool EqualsIgnoreCase(byte[] left, byte[] right); + protected abstract bool Equals(ReadOnlySpan left, ReadOnlySpan right); + protected abstract bool EqualsIgnoreCase(ReadOnlySpan left, ReadOnlySpan right); public static IEnumerable ValidAsciiInputs { @@ -140,9 +145,32 @@ public void Equals_EqualValues_ButNonAscii_ReturnsFalse(byte[] input) [MemberData(nameof(ContainingNonAsciiCharactersBuffers))] public void EqualsIgnoreCase_EqualValues_ButNonAscii_ReturnsFalse(byte[] input) => Assert.False(EqualsIgnoreCase(input, input)); + + [Theory] + [InlineData(PoisonPagePlacement.After, PoisonPagePlacement.After)] + [InlineData(PoisonPagePlacement.After, PoisonPagePlacement.Before)] + [InlineData(PoisonPagePlacement.Before, PoisonPagePlacement.After)] + [InlineData(PoisonPagePlacement.Before, PoisonPagePlacement.Before)] + public void Boundaries_Are_Respected(PoisonPagePlacement leftPoison, PoisonPagePlacement rightPoison) + { + for (int size = 1; size < 129; size++) + { + using BoundedMemory left = BoundedMemory.Allocate(size, leftPoison); + using BoundedMemory right = BoundedMemory.Allocate(size, rightPoison); + + left.Span.Fill(default); + right.Span.Fill(default); + + left.MakeReadonly(); + right.MakeReadonly(); + + Assert.True(Equals(left.Span, right.Span)); + Assert.True(EqualsIgnoreCase(left.Span, right.Span)); + } + } } - public class AsciiEqualityTests_Byte_Byte : AsciiEqualityTests + public class AsciiEqualityTests_Byte_Byte : AsciiEqualityTests { protected override bool Equals(string left, string right) => Ascii.Equals(Encoding.ASCII.GetBytes(left), Encoding.ASCII.GetBytes(right)); @@ -155,9 +183,15 @@ protected override bool Equals(byte[] left, byte[] right) protected override bool EqualsIgnoreCase(byte[] left, byte[] right) => Ascii.EqualsIgnoreCase(left, right); + + protected override bool Equals(ReadOnlySpan left, ReadOnlySpan right) + => Ascii.Equals(left, right); + + protected override bool EqualsIgnoreCase(ReadOnlySpan left, ReadOnlySpan right) + => Ascii.EqualsIgnoreCase(left, right); } - public class AsciiEqualityTests_Byte_Char : AsciiEqualityTests + public class AsciiEqualityTests_Byte_Char : AsciiEqualityTests { protected override bool Equals(string left, string right) => Ascii.Equals(Encoding.ASCII.GetBytes(left), right); @@ -170,9 +204,15 @@ protected override bool Equals(byte[] left, byte[] right) protected override bool EqualsIgnoreCase(byte[] left, byte[] right) => Ascii.EqualsIgnoreCase(left, right.Select(b => (char)b).ToArray()); + + protected override bool Equals(ReadOnlySpan left, ReadOnlySpan right) + => Ascii.Equals(left, right); + + protected override bool EqualsIgnoreCase(ReadOnlySpan left, ReadOnlySpan right) + => Ascii.EqualsIgnoreCase(left, right); } - public class AsciiEqualityTests_Char_Byte : AsciiEqualityTests + public class AsciiEqualityTests_Char_Byte : AsciiEqualityTests { protected override bool Equals(string left, string right) => Ascii.Equals(left, Encoding.ASCII.GetBytes(right)); @@ -185,9 +225,15 @@ protected override bool Equals(byte[] left, byte[] right) protected override bool EqualsIgnoreCase(byte[] left, byte[] right) => Ascii.EqualsIgnoreCase(left.Select(b => (char)b).ToArray(), right); + + protected override bool Equals(ReadOnlySpan left, ReadOnlySpan right) + => Ascii.Equals(left, right); + + protected override bool EqualsIgnoreCase(ReadOnlySpan left, ReadOnlySpan right) + => Ascii.EqualsIgnoreCase(left, right); } - public class AsciiEqualityTests_Char_Char : AsciiEqualityTests + public class AsciiEqualityTests_Char_Char : AsciiEqualityTests { protected override bool Equals(string left, string right) => Ascii.Equals(left, right); @@ -200,5 +246,11 @@ protected override bool Equals(byte[] left, byte[] right) protected override bool EqualsIgnoreCase(byte[] left, byte[] right) => Ascii.EqualsIgnoreCase(left.Select(b => (char)b).ToArray(), right.Select(b => (char)b).ToArray()); + + protected override bool Equals(ReadOnlySpan left, ReadOnlySpan right) + => Ascii.Equals(left, right); + + protected override bool EqualsIgnoreCase(ReadOnlySpan left, ReadOnlySpan right) + => Ascii.EqualsIgnoreCase(left, right); } }