From ab08501f7627cfa0ade8d98dbab85a5badec84b5 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 5 Sep 2023 11:28:07 -0400 Subject: [PATCH 1/4] Start vectorizing TensorPrimitives Just does two functions to establish the files into which the rest of the implementations can be moved. --- .../src/System.Numerics.Tensors.csproj | 8 + .../Numerics/Tensors/TensorPrimitives.cs | 46 +- .../Tensors/TensorPrimitives.netcore.cs | 163 ++++++ .../Tensors/TensorPrimitives.netstandard.cs | 104 ++++ .../tests/TensorPrimitivesTests.cs | 551 ++++++++++-------- 5 files changed, 576 insertions(+), 296 deletions(-) create mode 100644 src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs create mode 100644 src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index 5c188c2e0b841..097fa244ad491 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -16,6 +16,14 @@ + + + + + + + + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 10f29183ea286..1b313bc55bc4d 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -4,52 +4,8 @@ namespace System.Numerics.Tensors { /// Performs primitive tensor operations over spans of memory. - public static class TensorPrimitives + public static partial class TensorPrimitives { - /// Computes the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] + [i]. - public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] + y[i]; - } - } - - /// Computes the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] + . - public static void Add(ReadOnlySpan x, float y, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] + y; - } - } - /// Computes the element-wise result of: - . /// The first tensor, represented as a span. /// The second tensor, represented as a scalar. diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs new file mode 100644 index 0000000000000..7bdfeb1e3ffa6 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -0,0 +1,163 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; + +namespace System.Numerics.Tensors +{ + /// Performs primitive tensor operations over spans of memory. + public static partial class TensorPrimitives + { + /// Computes the element-wise result of: + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] + [i]. + public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int remaining = x.Length; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && remaining >= Vector512.Count) + { + do + { + Vector512.StoreUnsafe(Vector512.LoadUnsafe(ref xRef) + Vector512.LoadUnsafe(ref yRef), ref dRef); + + xRef = ref Unsafe.Add(ref xRef, Vector512.Count); + yRef = ref Unsafe.Add(ref yRef, Vector512.Count); + dRef = ref Unsafe.Add(ref dRef, Vector512.Count); + remaining -= Vector512.Count; + } + while (remaining >= Vector512.Count); + } +#endif + + if (Vector256.IsHardwareAccelerated && remaining >= Vector256.Count) + { + do + { + Vector256.StoreUnsafe(Vector256.LoadUnsafe(ref xRef) + Vector256.LoadUnsafe(ref yRef), ref dRef); + + xRef = ref Unsafe.Add(ref xRef, Vector256.Count); + yRef = ref Unsafe.Add(ref yRef, Vector256.Count); + dRef = ref Unsafe.Add(ref dRef, Vector256.Count); + remaining -= Vector256.Count; + } + while (remaining >= Vector256.Count); + } + + if (Vector128.IsHardwareAccelerated && remaining >= Vector128.Count) + { + do + { + Vector128.StoreUnsafe(Vector128.LoadUnsafe(ref xRef) + Vector128.LoadUnsafe(ref yRef), ref dRef); + + xRef = ref Unsafe.Add(ref xRef, Vector128.Count); + yRef = ref Unsafe.Add(ref yRef, Vector128.Count); + dRef = ref Unsafe.Add(ref dRef, Vector128.Count); + remaining -= Vector128.Count; + } + while (remaining >= Vector128.Count); + } + + while (remaining != 0) + { + dRef = xRef + yRef; + + xRef = ref Unsafe.Add(ref xRef, 1); + yRef = ref Unsafe.Add(ref yRef, 1); + dRef = ref Unsafe.Add(ref dRef, 1); + remaining--; + } + } + + /// Computes the element-wise result of: + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] + . + public static void Add(ReadOnlySpan x, float y, Span destination) + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int remaining = x.Length; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && remaining >= Vector512.Count) + { + Vector512 yVec = Vector512.Create(y); + do + { + Vector512.StoreUnsafe(Vector512.LoadUnsafe(ref xRef) + yVec, ref dRef); + + xRef = ref Unsafe.Add(ref xRef, Vector512.Count); + dRef = ref Unsafe.Add(ref dRef, Vector512.Count); + remaining -= Vector512.Count; + } + while (remaining >= Vector512.Count); + } +#endif + + if (Vector256.IsHardwareAccelerated && remaining >= Vector256.Count) + { + Vector256 yVec = Vector256.Create(y); + do + { + Vector256.StoreUnsafe(Vector256.LoadUnsafe(ref xRef) + yVec, ref dRef); + + xRef = ref Unsafe.Add(ref xRef, Vector256.Count); + dRef = ref Unsafe.Add(ref dRef, Vector256.Count); + remaining -= Vector256.Count; + } + while (remaining >= Vector256.Count); + } + + if (Vector128.IsHardwareAccelerated && remaining >= Vector128.Count) + { + Vector128 yVec = Vector128.Create(y); + do + { + Vector128.StoreUnsafe(Vector128.LoadUnsafe(ref xRef) + yVec, ref dRef); + + xRef = ref Unsafe.Add(ref xRef, Vector128.Count); + dRef = ref Unsafe.Add(ref dRef, Vector128.Count); + remaining -= Vector128.Count; + } + while (remaining >= Vector128.Count); + } + + while (remaining != 0) + { + dRef = xRef + y; + + xRef = ref Unsafe.Add(ref xRef, 1); + dRef = ref Unsafe.Add(ref dRef, 1); + remaining--; + } + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs new file mode 100644 index 0000000000000..3aa952ea3dde5 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.InteropServices; + +namespace System.Numerics.Tensors +{ + /// Performs primitive tensor operations over spans of memory. + public static unsafe partial class TensorPrimitives + { + /// Computes the element-wise result of: + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] + [i]. + public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + fixed (float* xPtr = &MemoryMarshal.GetReference(x), yPtr = &MemoryMarshal.GetReference(y), destPtr = &MemoryMarshal.GetReference(destination)) + { + float* px = xPtr, py = yPtr, pd = destPtr; + int remaining = x.Length; + + if (Vector.IsHardwareAccelerated && remaining >= Vector.Count) + { + do + { + *(Vector*)pd = *(Vector*)px + *(Vector*)py; + + px += Vector.Count; + py += Vector.Count; + pd += Vector.Count; + remaining -= Vector.Count; + } + while (remaining >= Vector.Count); + } + + while (remaining != 0) + { + *pd = *px + *py; + + px++; + py++; + pd++; + remaining--; + } + } + } + + /// Computes the element-wise result of: + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] + . + public static void Add(ReadOnlySpan x, float y, Span destination) + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + fixed (float* xPtr = &MemoryMarshal.GetReference(x), destPtr = &MemoryMarshal.GetReference(destination)) + { + float* px = xPtr, pd = destPtr; + int remaining = x.Length; + + if (Vector.IsHardwareAccelerated && remaining >= Vector.Count) + { + Vector yVec = new Vector(y); + do + { + *(Vector*)pd = *(Vector*)px + yVec; + + px += Vector.Count; + pd += Vector.Count; + remaining -= Vector.Count; + } + while (remaining >= Vector.Count); + } + + while (remaining != 0) + { + *pd = *px + y; + + px++; + pd++; + remaining--; + } + } + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index 1e751543831c3..a442df3d26af7 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -3,6 +3,7 @@ using Xunit; using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; #pragma warning disable xUnit1025 // reporting duplicate test cases due to not distinguishing 0.0 from -0.0 @@ -11,9 +12,9 @@ namespace System.Numerics.Tensors.Tests { public static class TensorPrimitivesTests { - private const int TensorSize = 512; - - private const int MismatchedTensorSize = 2; + public static IEnumerable TensorLengths => + from length in new[] { 1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17, 31, 32, 33, 100 } + select new object[] { length }; private static readonly Random s_random = new Random(20230828); @@ -46,586 +47,634 @@ private static float NextSingle() #endif } - [Fact] - public static void AddTwoTensors() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensors(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Add(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y[i]), destination[i]); } } - [Fact] - public static void AddTwoTensors_ThrowsForMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); } - [Fact] - public static void AddTwoTensors_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensors_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); } - [Fact] - public static void AddTensorAndScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Add(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y), destination[i]); } } - [Fact] - public static void AddTensorAndScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); } - [Fact] - public static void SubtractTwoTensors() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTwoTensors(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Subtract(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] - y[i]), destination[i]); } } - [Fact] - public static void SubtractTwoTensors_ThrowsForMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); } - [Fact] - public static void SubtractTwoTensors_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTwoTensors_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); } - [Fact] - public static void SubtractTensorAndScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTensorAndScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Subtract(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] - y), destination[i]); } } - [Fact] - public static void SubtractTensorAndScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); } - [Fact] - public static void MultiplyTwoTensors() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensors(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Multiply(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y[i]), destination[i]); } } - [Fact] - public static void MultiplyTwoTensors_ThrowsForMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); } - [Fact] - public static void MultiplyTwoTensors_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensors_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); } - [Fact] - public static void MultiplyTensorAndScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTensorAndScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Multiply(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y), destination[i]); } } - [Fact] - public static void MultiplyTensorAndScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); } - [Fact] - public static void DivideTwoTensors() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTwoTensors(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Divide(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] / y[i]), destination[i]); } } - [Fact] - public static void DivideTwoTensors_ThrowsForMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); } - [Fact] - public static void DivideTwoTensors_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTwoTensors_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); } - [Fact] - public static void DivideTensorAndScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTensorAndScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Divide(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] / y), destination[i]); } } - [Fact] - public static void DivideTensorAndScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); } - [Fact] - public static void NegateTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void NegateTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Negate(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(-x[i], destination[i]); } } - [Fact] - public static void NegateTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void NegateTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Negate(x, destination)); } - [Fact] - public static void AddTwoTensorsAndMultiplyWithThirdTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i]); } } - [Fact] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_y() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_y(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] multiplier = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] multiplier = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void AddTwoTensorsAndMultiplyWithScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); float multiplier = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y[i]) * multiplier, destination[i]); } } - [Fact] - public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); float multiplier = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void AddTensorAndScalarAndMultiplyWithTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y) * multiplier[i], destination[i]); } } - [Fact] - public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void MultiplyTwoTensorsAndAddWithThirdTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithThirdTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y[i]) + addend[i], destination[i]); } } - [Fact] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_y() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_y(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] addend = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] addend = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void MultiplyTwoTensorsAndAddWithScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); float addend = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y[i]) + addend, destination[i]); } } - [Fact] - public static void MultiplyTwoTensorsAndAddWithScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); float addend = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void MultiplyTensorAndScalarAndAddWithTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTensorAndScalarAndAddWithTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y) + addend[i], destination[i]); } } - [Fact] - public static void MultiplyTensorAndScalarAndAddWithTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTensorAndScalarAndAddWithTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void ExpTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ExpTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Exp(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Exp(x[i]), destination[i]); } } - [Fact] - public static void ExpTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ExpTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Exp(x, destination)); } - [Fact] - public static void LogTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void LogTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Log(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Log(x[i]), destination[i]); } } - [Fact] - public static void LogTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void LogTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Log(x, destination)); } - [Fact] - public static void CoshTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void CoshTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Cosh(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Cosh(x[i]), destination[i]); } } - [Fact] - public static void CoshTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void CoshTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Cosh(x, destination)); } - [Fact] - public static void SinhTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SinhTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Sinh(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Sinh(x[i]), destination[i]); } } - [Fact] - public static void SinhTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SinhTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Sinh(x, destination)); } - [Fact] - public static void TanhTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void TanhTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Tanh(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Tanh(x[i]), destination[i]); } } - [Fact] - public static void TanhTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void TanhTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Tanh(x, destination)); } From e4ceb72077c63832502d74ea2de392751134e64e Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 5 Sep 2023 13:15:03 -0400 Subject: [PATCH 2/4] Address PR feedback --- .../Tensors/TensorPrimitives.netcore.cs | 126 +++++++++--------- .../Tensors/TensorPrimitives.netstandard.cs | 56 ++++---- 2 files changed, 89 insertions(+), 93 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 7bdfeb1e3ffa6..9ca70c63384ad 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -32,60 +32,51 @@ public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); ref float dRef = ref MemoryMarshal.GetReference(destination); - int remaining = x.Length; + int i = 0, oneVectorFromEnd; #if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && remaining >= Vector512.Count) + if (Vector512.IsHardwareAccelerated) { - do + oneVectorFromEnd = x.Length - Vector512.Count; + while (i <= oneVectorFromEnd) { - Vector512.StoreUnsafe(Vector512.LoadUnsafe(ref xRef) + Vector512.LoadUnsafe(ref yRef), ref dRef); + Vector512 sum = Vector512.LoadUnsafe(ref xRef, (uint)i) + Vector512.LoadUnsafe(ref yRef, (uint)i); + Vector512.StoreUnsafe(sum, ref dRef, (uint)i); - xRef = ref Unsafe.Add(ref xRef, Vector512.Count); - yRef = ref Unsafe.Add(ref yRef, Vector512.Count); - dRef = ref Unsafe.Add(ref dRef, Vector512.Count); - remaining -= Vector512.Count; + i += Vector512.Count; } - while (remaining >= Vector512.Count); } #endif - if (Vector256.IsHardwareAccelerated && remaining >= Vector256.Count) + if (Vector256.IsHardwareAccelerated) { - do + oneVectorFromEnd = x.Length - Vector256.Count; + while (i <= oneVectorFromEnd) { - Vector256.StoreUnsafe(Vector256.LoadUnsafe(ref xRef) + Vector256.LoadUnsafe(ref yRef), ref dRef); + Vector256 sum = Vector256.LoadUnsafe(ref xRef, (uint)i) + Vector256.LoadUnsafe(ref yRef, (uint)i); + Vector256.StoreUnsafe(sum, ref dRef, (uint)i); - xRef = ref Unsafe.Add(ref xRef, Vector256.Count); - yRef = ref Unsafe.Add(ref yRef, Vector256.Count); - dRef = ref Unsafe.Add(ref dRef, Vector256.Count); - remaining -= Vector256.Count; + i += Vector256.Count; } - while (remaining >= Vector256.Count); } - if (Vector128.IsHardwareAccelerated && remaining >= Vector128.Count) + if (Vector128.IsHardwareAccelerated) { - do + oneVectorFromEnd = x.Length - Vector128.Count; + while (i <= oneVectorFromEnd) { - Vector128.StoreUnsafe(Vector128.LoadUnsafe(ref xRef) + Vector128.LoadUnsafe(ref yRef), ref dRef); + Vector128 sum = Vector128.LoadUnsafe(ref xRef, (uint)i) + Vector128.LoadUnsafe(ref yRef, (uint)i); + Vector128.StoreUnsafe(sum, ref dRef, (uint)i); - xRef = ref Unsafe.Add(ref xRef, Vector128.Count); - yRef = ref Unsafe.Add(ref yRef, Vector128.Count); - dRef = ref Unsafe.Add(ref dRef, Vector128.Count); - remaining -= Vector128.Count; + i += Vector128.Count; } - while (remaining >= Vector128.Count); } - while (remaining != 0) + while (i < x.Length) { - dRef = xRef + yRef; + Unsafe.Add(ref dRef, i) = Unsafe.Add(ref xRef, i) + Unsafe.Add(ref yRef, i); - xRef = ref Unsafe.Add(ref xRef, 1); - yRef = ref Unsafe.Add(ref yRef, 1); - dRef = ref Unsafe.Add(ref dRef, 1); - remaining--; + i++; } } @@ -104,59 +95,66 @@ public static void Add(ReadOnlySpan x, float y, Span destination) ref float xRef = ref MemoryMarshal.GetReference(x); ref float dRef = ref MemoryMarshal.GetReference(destination); - int remaining = x.Length; + int i = 0, oneVectorFromEnd; #if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && remaining >= Vector512.Count) + if (Vector512.IsHardwareAccelerated) { - Vector512 yVec = Vector512.Create(y); - do + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) { - Vector512.StoreUnsafe(Vector512.LoadUnsafe(ref xRef) + yVec, ref dRef); - - xRef = ref Unsafe.Add(ref xRef, Vector512.Count); - dRef = ref Unsafe.Add(ref dRef, Vector512.Count); - remaining -= Vector512.Count; + Vector512 yVec = Vector512.Create(y); + do + { + Vector512 sum = Vector512.LoadUnsafe(ref xRef, (uint)i) + yVec; + Vector512.StoreUnsafe(sum, ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); } - while (remaining >= Vector512.Count); } #endif - if (Vector256.IsHardwareAccelerated && remaining >= Vector256.Count) + if (Vector256.IsHardwareAccelerated) { - Vector256 yVec = Vector256.Create(y); - do + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) { - Vector256.StoreUnsafe(Vector256.LoadUnsafe(ref xRef) + yVec, ref dRef); - - xRef = ref Unsafe.Add(ref xRef, Vector256.Count); - dRef = ref Unsafe.Add(ref dRef, Vector256.Count); - remaining -= Vector256.Count; + Vector256 yVec = Vector256.Create(y); + do + { + Vector256 sum = Vector256.LoadUnsafe(ref xRef, (uint)i) + yVec; + Vector256.StoreUnsafe(sum, ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); } - while (remaining >= Vector256.Count); } - if (Vector128.IsHardwareAccelerated && remaining >= Vector128.Count) + if (Vector128.IsHardwareAccelerated) { - Vector128 yVec = Vector128.Create(y); - do + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) { - Vector128.StoreUnsafe(Vector128.LoadUnsafe(ref xRef) + yVec, ref dRef); - - xRef = ref Unsafe.Add(ref xRef, Vector128.Count); - dRef = ref Unsafe.Add(ref dRef, Vector128.Count); - remaining -= Vector128.Count; + Vector128 yVec = Vector128.Create(y); + do + { + Vector128 sum = Vector128.LoadUnsafe(ref xRef, (uint)i) + yVec; + Vector128.StoreUnsafe(sum, ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); } - while (remaining >= Vector128.Count); } - while (remaining != 0) + while (i < x.Length) { - dRef = xRef + y; + Unsafe.Add(ref dRef, i) = Unsafe.Add(ref xRef, i) + y; - xRef = ref Unsafe.Add(ref xRef, 1); - dRef = ref Unsafe.Add(ref dRef, 1); - remaining--; + i++; } } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index 3aa952ea3dde5..3ed81a368e4b9 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -30,30 +30,28 @@ public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span fixed (float* xPtr = &MemoryMarshal.GetReference(x), yPtr = &MemoryMarshal.GetReference(y), destPtr = &MemoryMarshal.GetReference(destination)) { float* px = xPtr, py = yPtr, pd = destPtr; - int remaining = x.Length; + int i = 0, oneVectorFromEnd; - if (Vector.IsHardwareAccelerated && remaining >= Vector.Count) + if (Vector.IsHardwareAccelerated) { - do + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) { - *(Vector*)pd = *(Vector*)px + *(Vector*)py; + do + { + *(Vector*)(pd + i) = *(Vector*)(px + i) + *(Vector*)(py + i); - px += Vector.Count; - py += Vector.Count; - pd += Vector.Count; - remaining -= Vector.Count; + i += Vector.Count; + } + while (i <= oneVectorFromEnd); } - while (remaining >= Vector.Count); } - while (remaining != 0) + while (i < x.Length) { - *pd = *px + *py; + *(pd + i) = *(px + i) + *(py + i); - px++; - py++; - pd++; - remaining--; + i++; } } } @@ -74,29 +72,29 @@ public static void Add(ReadOnlySpan x, float y, Span destination) fixed (float* xPtr = &MemoryMarshal.GetReference(x), destPtr = &MemoryMarshal.GetReference(destination)) { float* px = xPtr, pd = destPtr; - int remaining = x.Length; + int i = 0, oneVectorFromEnd; - if (Vector.IsHardwareAccelerated && remaining >= Vector.Count) + if (Vector.IsHardwareAccelerated) { - Vector yVec = new Vector(y); - do + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) { - *(Vector*)pd = *(Vector*)px + yVec; + Vector yVec = new Vector(y); + do + { + *(Vector*)(pd + i) = *(Vector*)(px + i) + yVec; - px += Vector.Count; - pd += Vector.Count; - remaining -= Vector.Count; + i += Vector.Count; + } + while (i <= oneVectorFromEnd); } - while (remaining >= Vector.Count); } - while (remaining != 0) + while (i < x.Length) { - *pd = *px + y; + *(pd + i) = *(px + i) + y; - px++; - pd++; - remaining--; + i++; } } } From 06f9760ecacbcb0709af0bf613544dceb8d12f0c Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 5 Sep 2023 15:36:13 -0400 Subject: [PATCH 3/4] Refactor to using shared implementations for operators --- .../src/Resources/Strings.resx | 4 +- .../Numerics/Tensors/TensorPrimitives.cs | 324 ------ .../Tensors/TensorPrimitives.netcore.cs | 959 +++++++++++++++++- .../Tensors/TensorPrimitives.netstandard.cs | 578 ++++++++++- .../src/System/ThrowHelper.cs | 4 +- 5 files changed, 1470 insertions(+), 399 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx index f9f7f45078a04..d70054a241126 100644 --- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx +++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx @@ -121,6 +121,6 @@ Destination is too short. - Length of '{0}' must be same as length of '{1}'. + Input span arguments must all have the same length. - \ No newline at end of file + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 1b313bc55bc4d..0b69218338c3c 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -6,330 +6,6 @@ namespace System.Numerics.Tensors /// Performs primitive tensor operations over spans of memory. public static partial class TensorPrimitives { - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] - [i]. - public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] - y[i]; - } - } - - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] - . - public static void Subtract(ReadOnlySpan x, float y, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] - y; - } - } - - /// Computes the element-wise result of: * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] * . - public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] * y[i]; - } - } - - /// Computes the element-wise result of: * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// - /// This method effectively does [i] = [i] * . - /// This method corresponds to the scal method defined by BLAS1. - /// - public static void Multiply(ReadOnlySpan x, float y, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] * y; - } - } - - /// Computes the element-wise result of: / . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] / y[i]; - } - } - - /// Computes the element-wise result of: / . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, float y, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] / y; - } - } - - /// Computes the element-wise result of: -. - /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = -[i]. - public static void Negate(ReadOnlySpan x, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = -x[i]; - } - } - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * [i]. - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length != multiplier.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(multiplier)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] + y[i]) * multiplier[i]; - } - } - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * . - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] + y[i]) * multiplier; - } - } - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + ) * [i]. - public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) - { - if (x.Length != multiplier.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(multiplier)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] + y) * multiplier[i]; - } - } - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] * [i]) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length != addend.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(addend)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] * y[i]) + addend[i]; - } - } - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// - /// This method effectively does [i] = ([i] * [i]) + . - /// This method corresponds to the axpy method defined by BLAS1. - /// - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] * y[i]) + addend; - } - } - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] * ) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) - { - if (x.Length != addend.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(addend)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] * y) + addend[i]; - } - } - /// Computes the element-wise result of: pow(e, ). /// The tensor, represented as a span. /// The destination tensor, represented as a span. diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 9ca70c63384ad..81b69b6c1c4e7 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -17,11 +17,281 @@ public static partial class TensorPrimitives /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = [i] + [i]. - public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) + public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise result of: + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] + . + public static void Add(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of: - . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] - [i]. + public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise result of: - . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] - . + public static void Subtract(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of: * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] * . + public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise result of: * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// + /// This method effectively does [i] = [i] * . + /// This method corresponds to the scal method defined by BLAS1. + /// + public static void Multiply(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of: / . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] / . + public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise result of: / . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] / . + public static void Divide(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of: -. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = -[i]. + public static void Negate(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + [i]) * [i]. + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + [i]) * . + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + ) * [i]. + public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] * [i]) + [i]. + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// + /// This method effectively does [i] = ([i] * [i]) + . + /// This method corresponds to the axpy method defined by BLAS1. + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] * ) + [i]. + public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); + + private static unsafe void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination) + where TUnaryOperator : IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + Vector512.StoreUnsafe( + TUnaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)i)), + ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + Vector512.StoreUnsafe( + TUnaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + Vector256.StoreUnsafe( + TUnaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)i)), + ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + Vector256.StoreUnsafe( + TUnaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + Vector128.StoreUnsafe( + TUnaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)i)), + ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + Vector128.StoreUnsafe( + TUnaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke( + Unsafe.Add(ref xRef, i)); + + i++; + } + } + + private static unsafe void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination) + where TBinaryOperator : IBinaryOperator { if (x.Length != y.Length) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); } if (x.Length > destination.Length) @@ -38,12 +308,33 @@ public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span if (Vector512.IsHardwareAccelerated) { oneVectorFromEnd = x.Length - Vector512.Count; - while (i <= oneVectorFromEnd) + if (i <= oneVectorFromEnd) { - Vector512 sum = Vector512.LoadUnsafe(ref xRef, (uint)i) + Vector512.LoadUnsafe(ref yRef, (uint)i); - Vector512.StoreUnsafe(sum, ref dRef, (uint)i); + // Loop handling one vector at a time. + do + { + Vector512.StoreUnsafe( + TBinaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i)), + ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + Vector512.StoreUnsafe( + TBinaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } - i += Vector512.Count; + return; } } #endif @@ -51,49 +342,502 @@ public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span if (Vector256.IsHardwareAccelerated) { oneVectorFromEnd = x.Length - Vector256.Count; - while (i <= oneVectorFromEnd) + if (i <= oneVectorFromEnd) { - Vector256 sum = Vector256.LoadUnsafe(ref xRef, (uint)i) + Vector256.LoadUnsafe(ref yRef, (uint)i); - Vector256.StoreUnsafe(sum, ref dRef, (uint)i); + // Loop handling one vector at a time. + do + { + Vector256.StoreUnsafe( + TBinaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i)), + ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); - i += Vector256.Count; + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + Vector256.StoreUnsafe( + TBinaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; } } if (Vector128.IsHardwareAccelerated) { oneVectorFromEnd = x.Length - Vector128.Count; - while (i <= oneVectorFromEnd) + if (i <= oneVectorFromEnd) { - Vector128 sum = Vector128.LoadUnsafe(ref xRef, (uint)i) + Vector128.LoadUnsafe(ref yRef, (uint)i); - Vector128.StoreUnsafe(sum, ref dRef, (uint)i); + // Loop handling one vector at a time. + do + { + Vector128.StoreUnsafe( + TBinaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i)), + ref dRef, (uint)i); - i += Vector128.Count; + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + Vector128.StoreUnsafe( + TBinaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; } } while (i < x.Length) { - Unsafe.Add(ref dRef, i) = Unsafe.Add(ref xRef, i) + Unsafe.Add(ref yRef, i); + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), Unsafe.Add(ref yRef, i)); i++; } } - /// Computes the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] + . - public static void Add(ReadOnlySpan x, float y, Span destination) + private static unsafe void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TBinaryOperator : IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + Vector512 yVec = Vector512.Create(y); + + // Loop handling one vector at a time. + do + { + Vector512.StoreUnsafe( + TBinaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)i), + yVec), + ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + Vector512.StoreUnsafe( + TBinaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + yVec), + ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + Vector256 yVec = Vector256.Create(y); + + // Loop handling one vector at a time. + do + { + Vector256.StoreUnsafe( + TBinaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)i), + yVec), + ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + Vector256.StoreUnsafe( + TBinaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + yVec), + ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + Vector128 yVec = Vector128.Create(y); + + // Loop handling one vector at a time. + do + { + Vector128.StoreUnsafe( + TBinaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)i), + yVec), + ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + Vector128.StoreUnsafe( + TBinaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + yVec), + ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), y); + + i++; + } + } + + private static unsafe void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + Vector512.StoreUnsafe( + TTernaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i), + Vector512.LoadUnsafe(ref zRef, (uint)i)), + ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + Vector512.StoreUnsafe( + TTernaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex), + Vector512.LoadUnsafe(ref zRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + Vector256.StoreUnsafe( + TTernaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i), + Vector256.LoadUnsafe(ref zRef, (uint)i)), + ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + Vector256.StoreUnsafe( + TTernaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex), + Vector256.LoadUnsafe(ref zRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + Vector128.StoreUnsafe( + TTernaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i), + Vector128.LoadUnsafe(ref zRef, (uint)i)), + ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + Vector128.StoreUnsafe( + TTernaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex), + Vector128.LoadUnsafe(ref zRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke( + Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); + + i++; + } + } + + private static unsafe void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) + where TTernaryOperator : ITernaryOperator { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + if (x.Length > destination.Length) { ThrowHelper.ThrowArgument_DestinationTooShort(); } ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + Vector512 zVec = Vector512.Create(z); + + // Loop handling one vector at a time. + do + { + Vector512.StoreUnsafe( + TTernaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i), + zVec), + ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + Vector512.StoreUnsafe( + TTernaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex), + zVec), + ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + Vector256 zVec = Vector256.Create(z); + + // Loop handling one vector at a time. + do + { + Vector256.StoreUnsafe( + TTernaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i), + zVec), + ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + Vector256.StoreUnsafe( + TTernaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex), + zVec), + ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + Vector128 zVec = Vector128.Create(z); + + // Loop handling one vector at a time. + do + { + Vector128.StoreUnsafe( + TTernaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i), + zVec), + ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + Vector128.StoreUnsafe( + TTernaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex), + zVec), + ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke( + Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + + i++; + } + } + + private static unsafe void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); ref float dRef = ref MemoryMarshal.GetReference(destination); int i = 0, oneVectorFromEnd; @@ -104,14 +848,34 @@ public static void Add(ReadOnlySpan x, float y, Span destination) if (i <= oneVectorFromEnd) { Vector512 yVec = Vector512.Create(y); + + // Loop handling one vector at a time. do { - Vector512 sum = Vector512.LoadUnsafe(ref xRef, (uint)i) + yVec; - Vector512.StoreUnsafe(sum, ref dRef, (uint)i); + Vector512.StoreUnsafe( + TTernaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector512.LoadUnsafe(ref zRef, (uint)i)), + ref dRef, (uint)i); i += Vector512.Count; } while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + Vector512.StoreUnsafe( + TTernaryOperator.Invoke( + Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector512.LoadUnsafe(ref zRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; } } #endif @@ -122,14 +886,34 @@ public static void Add(ReadOnlySpan x, float y, Span destination) if (i <= oneVectorFromEnd) { Vector256 yVec = Vector256.Create(y); + + // Loop handling one vector at a time. do { - Vector256 sum = Vector256.LoadUnsafe(ref xRef, (uint)i) + yVec; - Vector256.StoreUnsafe(sum, ref dRef, (uint)i); + Vector256.StoreUnsafe( + TTernaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector256.LoadUnsafe(ref zRef, (uint)i)), + ref dRef, (uint)i); i += Vector256.Count; } while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + Vector256.StoreUnsafe( + TTernaryOperator.Invoke( + Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector256.LoadUnsafe(ref zRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; } } @@ -139,23 +923,140 @@ public static void Add(ReadOnlySpan x, float y, Span destination) if (i <= oneVectorFromEnd) { Vector128 yVec = Vector128.Create(y); + + // Loop handling one vector at a time. do { - Vector128 sum = Vector128.LoadUnsafe(ref xRef, (uint)i) + yVec; - Vector128.StoreUnsafe(sum, ref dRef, (uint)i); + Vector128.StoreUnsafe( + TTernaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector128.LoadUnsafe(ref zRef, (uint)i)), + ref dRef, (uint)i); i += Vector128.Count; } while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + Vector128.StoreUnsafe( + TTernaryOperator.Invoke( + Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector128.LoadUnsafe(ref zRef, lastVectorIndex)), + ref dRef, lastVectorIndex); + } + + return; } } while (i < x.Length) { - Unsafe.Add(ref dRef, i) = Unsafe.Add(ref xRef, i) + y; + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke( + Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); i++; } } + + private readonly struct AddOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x + y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; +#endif + } + + private readonly struct SubtractOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x - y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; +#endif + } + + private readonly struct MultiplyOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x * y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; +#endif + } + + private readonly struct DivideOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x / y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; +#endif + } + + private readonly struct NegateOperator : IUnaryOperator + { + public static float Invoke(float x) => -x; + public static Vector128 Invoke(Vector128 x) => -x; + public static Vector256 Invoke(Vector256 x) => -x; + public static Vector512 Invoke(Vector512 x) => -x; + } + + private readonly struct AddMultiplyOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x + y) * z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; + } + + private readonly struct MultiplyAddOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x * y) + z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; + } + + private interface IUnaryOperator + { + static abstract float Invoke(float x); + static abstract Vector128 Invoke(Vector128 x); + static abstract Vector256 Invoke(Vector256 x); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x); +#endif + } + + private interface IBinaryOperator + { + static abstract float Invoke(float x, float y); + static abstract Vector128 Invoke(Vector128 x, Vector128 y); + static abstract Vector256 Invoke(Vector256 x, Vector256 y); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x, Vector512 y); +#endif + } + + private interface ITernaryOperator + { + static abstract float Invoke(float x, float y, float z); + static abstract Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z); + static abstract Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z); +#endif + } } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index 3ed81a368e4b9..b10984be641bd 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; namespace System.Numerics.Tensors @@ -15,11 +16,215 @@ public static unsafe partial class TensorPrimitives /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = [i] + [i]. - public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) + public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination, default(AddOperator)); + + /// Computes the element-wise result of: + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] + . + public static void Add(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination, default(AddOperator)); + + /// Computes the element-wise result of: - . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] - [i]. + public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination, default(SubtractOperator)); + + /// Computes the element-wise result of: - . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] - . + public static void Subtract(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination, default(SubtractOperator)); + + /// Computes the element-wise result of: * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] * . + public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination, default(MultiplyOperator)); + + /// Computes the element-wise result of: * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// + /// This method effectively does [i] = [i] * . + /// This method corresponds to the scal method defined by BLAS1. + /// + public static void Multiply(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination, default(MultiplyOperator)); + + /// Computes the element-wise result of: / . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] / . + public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination, default(DivideOperator)); + + /// Computes the element-wise result of: / . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] / . + public static void Divide(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination, default(DivideOperator)); + + /// Computes the element-wise result of: -. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = -[i]. + public static void Negate(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination, default(NegateOperator)); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + [i]) * [i]. + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination, default(AddMultiplyOperator)); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + [i]) * . + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination, default(AddMultiplyOperator)); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + ) * [i]. + public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination, default(AddMultiplyOperator)); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] * [i]) + [i]. + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, addend, destination, default(MultiplyAddOperator)); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// + /// This method effectively does [i] = ([i] * [i]) + . + /// This method corresponds to the axpy method defined by BLAS1. + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, addend, destination, default(MultiplyAddOperator)); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] * ) + [i]. + public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, addend, destination, default(MultiplyAddOperator)); + + private static void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination, TUnaryOperator op) + where TUnaryOperator : IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) + { + // Loop handling one vector at a time. + do + { + Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, i))); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex))); + } + + return; + } + } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i)); + + i++; + } + } + + private static void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op) + where TBinaryOperator : IBinaryOperator { if (x.Length != y.Length) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); } if (x.Length > destination.Length) @@ -27,76 +232,365 @@ public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span ThrowHelper.ThrowArgument_DestinationTooShort(); } - fixed (float* xPtr = &MemoryMarshal.GetReference(x), yPtr = &MemoryMarshal.GetReference(y), destPtr = &MemoryMarshal.GetReference(destination)) - { - float* px = xPtr, py = yPtr, pd = destPtr; - int i = 0, oneVectorFromEnd; + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; - if (Vector.IsHardwareAccelerated) + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + // Loop handling one vector at a time. + do { - do - { - *(Vector*)(pd + i) = *(Vector*)(px + i) + *(Vector*)(py + i); + Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, i)), + Unsafe.As>(ref Unsafe.Add(ref yRef, i))); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), + Unsafe.As>(ref Unsafe.Add(ref yRef, lastVectorIndex))); } + + return; } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = + op.Invoke( + Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); + + i++; + } + } - while (i < x.Length) + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination, TBinaryOperator op) + where TBinaryOperator : IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) { - *(pd + i) = *(px + i) + *(py + i); + // Loop handling one vector at a time. + Vector yVec = new Vector(y); + do + { + Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, i)), + yVec); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); - i++; + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), + yVec); + } + + return; } } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = + op.Invoke( + Unsafe.Add(ref xRef, i), + y); + + i++; + } } - /// Computes the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] + . - public static void Add(ReadOnlySpan x, float y, Span destination) + private static void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op) + where TTernaryOperator : ITernaryOperator { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + if (x.Length > destination.Length) { ThrowHelper.ThrowArgument_DestinationTooShort(); } - fixed (float* xPtr = &MemoryMarshal.GetReference(x), destPtr = &MemoryMarshal.GetReference(destination)) + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) + { + // Loop handling one vector at a time. + do + { + Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, i)), + Unsafe.As>(ref Unsafe.Add(ref yRef, i)), + Unsafe.As>(ref Unsafe.Add(ref zRef, i))); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), + Unsafe.As>(ref Unsafe.Add(ref yRef, lastVectorIndex)), + Unsafe.As>(ref Unsafe.Add(ref zRef, lastVectorIndex))); + } + + return; + } + } + + // Loop handling one element at a time. + while (i < x.Length) { - float* px = xPtr, pd = destPtr; - int i = 0, oneVectorFromEnd; + Unsafe.Add(ref dRef, i) = op.Invoke( + Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); - if (Vector.IsHardwareAccelerated) + i++; + } + } + + private static void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + Vector zVec = new Vector(z); + + // Loop handling one vector at a time. + do { - Vector yVec = new Vector(y); - do - { - *(Vector*)(pd + i) = *(Vector*)(px + i) + yVec; - - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, i)), + Unsafe.As>(ref Unsafe.Add(ref yRef, i)), + zVec); + + i += Vector.Count; } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), + Unsafe.As>(ref Unsafe.Add(ref yRef, lastVectorIndex)), + zVec); + } + + return; } + } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke( + Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + + i++; + } + } + + private static void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } - while (i < x.Length) + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) { - *(pd + i) = *(px + i) + y; + Vector yVec = new Vector(y); + + // Loop handling one vector at a time. + do + { + Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, i)), + yVec, + Unsafe.As>(ref Unsafe.Add(ref zRef, i))); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = + op.Invoke( + Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), + yVec, + Unsafe.As>(ref Unsafe.Add(ref zRef, lastVectorIndex))); + } - i++; + return; } } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke( + Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); + + i++; + } + } + + private readonly struct AddOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x + y; + public Vector Invoke(Vector x, Vector y) => x + y; + } + + private readonly struct SubtractOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x - y; + public Vector Invoke(Vector x, Vector y) => x - y; + } + + private readonly struct MultiplyOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x * y; + public Vector Invoke(Vector x, Vector y) => x * y; + } + + private readonly struct DivideOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x / y; + public Vector Invoke(Vector x, Vector y) => x / y; + } + + private readonly struct NegateOperator : IUnaryOperator + { + public float Invoke(float x) => -x; + public Vector Invoke(Vector x) => -x; + } + + private readonly struct AddMultiplyOperator : ITernaryOperator + { + public float Invoke(float x, float y, float z) => (x + y) * z; + public Vector Invoke(Vector x, Vector y, Vector z) => (x + y) * z; + } + + private readonly struct MultiplyAddOperator : ITernaryOperator + { + public float Invoke(float x, float y, float z) => (x * y) + z; + public Vector Invoke(Vector x, Vector y, Vector z) => (x * y) + z; + } + + private interface IUnaryOperator + { + float Invoke(float x); + Vector Invoke(Vector x); + } + + private interface IBinaryOperator + { + float Invoke(float x, float y); + Vector Invoke(Vector x, Vector y); + } + + private interface ITernaryOperator + { + float Invoke(float x, float y, float z); + Vector Invoke(Vector x, Vector y, Vector z); } } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs index 5ae36ffbed768..ebe79895196ac 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs @@ -26,7 +26,7 @@ internal static class ThrowHelper public static void ThrowArgument_DestinationTooShort() => throw new ArgumentException(SR.Argument_DestinationTooShort); [DoesNotReturn] - public static void ThrowArgument_SpansMustHaveSameLength(string paramName1, string paramName2) - => throw new ArgumentException(SR.Format(SR.Argument_SpansMustHaveSameLength, paramName1, paramName2), paramName1); + public static void ThrowArgument_SpansMustHaveSameLength() + => throw new ArgumentException(SR.Argument_SpansMustHaveSameLength); } } From 6848d2196939898233650ad8d589d74cdb431d28 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 6 Sep 2023 21:49:42 -0400 Subject: [PATCH 4/4] Cleanup - Share public methods between implementations - Reduce boilerplate code in netstandard impl with AsVector helper - Align arguments - Add some missing tests based on code coverage - Add argument name to throw helper --- .../Numerics/Tensors/TensorPrimitives.cs | 158 ++++++ .../Tensors/TensorPrimitives.netcore.cs | 465 ++++-------------- .../Tensors/TensorPrimitives.netstandard.cs | 297 +++-------- .../src/System/ThrowHelper.cs | 22 +- .../tests/TensorPrimitivesTests.cs | 64 ++- 5 files changed, 360 insertions(+), 646 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 0b69218338c3c..08bd9d362217e 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -6,6 +6,164 @@ namespace System.Numerics.Tensors /// Performs primitive tensor operations over spans of memory. public static partial class TensorPrimitives { + /// Computes the element-wise result of: + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] + [i]. + public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise result of: + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] + . + public static void Add(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of: - . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] - [i]. + public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise result of: - . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] - . + public static void Subtract(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of: * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] * . + public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise result of: * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// + /// This method effectively does [i] = [i] * . + /// This method corresponds to the scal method defined by BLAS1. + /// + public static void Multiply(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of: / . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = [i] / . + public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise result of: / . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = [i] / . + public static void Divide(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of: -. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// This method effectively does [i] = -[i]. + public static void Negate(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + [i]) * [i]. + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + [i]) * . + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise result of: ( + ) * . + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] + ) * [i]. + public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] * [i]) + [i]. + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// + /// This method effectively does [i] = ([i] * [i]) + . + /// This method corresponds to the axpy method defined by BLAS1. + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); + + /// Computes the element-wise result of: ( * ) + . + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of '' must be same as length of ''. + /// Destination is too short. + /// This method effectively does [i] = ([i] * ) + [i]. + public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); + /// Computes the element-wise result of: pow(e, ). /// The tensor, represented as a span. /// The destination tensor, represented as a span. diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 81b69b6c1c4e7..1233f54901c80 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -7,167 +7,8 @@ namespace System.Numerics.Tensors { - /// Performs primitive tensor operations over spans of memory. public static partial class TensorPrimitives { - /// Computes the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] + [i]. - public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); - - /// Computes the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] + . - public static void Add(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); - - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] - [i]. - public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); - - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] - . - public static void Subtract(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); - - /// Computes the element-wise result of: * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] * . - public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); - - /// Computes the element-wise result of: * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// - /// This method effectively does [i] = [i] * . - /// This method corresponds to the scal method defined by BLAS1. - /// - public static void Multiply(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); - - /// Computes the element-wise result of: / . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); - - /// Computes the element-wise result of: / . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); - - /// Computes the element-wise result of: -. - /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = -[i]. - public static void Negate(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * [i]. - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * . - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + ) * [i]. - public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] * [i]) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// - /// This method effectively does [i] = ([i] * [i]) + . - /// This method corresponds to the axpy method defined by BLAS1. - /// - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] * ) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); - private static unsafe void InvokeSpanIntoSpan( ReadOnlySpan x, Span destination) where TUnaryOperator : IUnaryOperator @@ -190,10 +31,7 @@ private static unsafe void InvokeSpanIntoSpan( // Loop handling one vector at a time. do { - Vector512.StoreUnsafe( - TUnaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, (uint)i)), - ref dRef, (uint)i); + TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector512.Count; } @@ -203,10 +41,7 @@ private static unsafe void InvokeSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.StoreUnsafe( - TUnaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -222,10 +57,7 @@ private static unsafe void InvokeSpanIntoSpan( // Loop handling one vector at a time. do { - Vector256.StoreUnsafe( - TUnaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, (uint)i)), - ref dRef, (uint)i); + TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector256.Count; } @@ -235,10 +67,7 @@ private static unsafe void InvokeSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.StoreUnsafe( - TUnaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -253,10 +82,7 @@ private static unsafe void InvokeSpanIntoSpan( // Loop handling one vector at a time. do { - Vector128.StoreUnsafe( - TUnaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, (uint)i)), - ref dRef, (uint)i); + TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector128.Count; } @@ -266,10 +92,7 @@ private static unsafe void InvokeSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.StoreUnsafe( - TUnaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -278,8 +101,7 @@ private static unsafe void InvokeSpanIntoSpan( while (i < x.Length) { - Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke( - Unsafe.Add(ref xRef, i)); + Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); i++; } @@ -313,11 +135,8 @@ private static unsafe void InvokeSpanSpanIntoSpan( // Loop handling one vector at a time. do { - Vector512.StoreUnsafe( - TBinaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i)), - ref dRef, (uint)i); + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector512.Count; } @@ -327,11 +146,8 @@ private static unsafe void InvokeSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.StoreUnsafe( - TBinaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -347,11 +163,8 @@ private static unsafe void InvokeSpanSpanIntoSpan( // Loop handling one vector at a time. do { - Vector256.StoreUnsafe( - TBinaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i)), - ref dRef, (uint)i); + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector256.Count; } @@ -361,11 +174,8 @@ private static unsafe void InvokeSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.StoreUnsafe( - TBinaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -380,11 +190,8 @@ private static unsafe void InvokeSpanSpanIntoSpan( // Loop handling one vector at a time. do { - Vector128.StoreUnsafe( - TBinaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i)), - ref dRef, (uint)i); + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector128.Count; } @@ -394,11 +201,8 @@ private static unsafe void InvokeSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.StoreUnsafe( - TBinaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -407,7 +211,8 @@ private static unsafe void InvokeSpanSpanIntoSpan( while (i < x.Length) { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), Unsafe.Add(ref yRef, i)); + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); i++; } @@ -437,11 +242,8 @@ private static unsafe void InvokeSpanScalarIntoSpan( // Loop handling one vector at a time. do { - Vector512.StoreUnsafe( - TBinaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, (uint)i), - yVec), - ref dRef, (uint)i); + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + yVec).StoreUnsafe(ref dRef, (uint)i); i += Vector512.Count; } @@ -451,11 +253,8 @@ private static unsafe void InvokeSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.StoreUnsafe( - TBinaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec), - ref dRef, lastVectorIndex); + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + yVec).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -473,11 +272,8 @@ private static unsafe void InvokeSpanScalarIntoSpan( // Loop handling one vector at a time. do { - Vector256.StoreUnsafe( - TBinaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, (uint)i), - yVec), - ref dRef, (uint)i); + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + yVec).StoreUnsafe(ref dRef, (uint)i); i += Vector256.Count; } @@ -487,11 +283,8 @@ private static unsafe void InvokeSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.StoreUnsafe( - TBinaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec), - ref dRef, lastVectorIndex); + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + yVec).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -508,11 +301,8 @@ private static unsafe void InvokeSpanScalarIntoSpan( // Loop handling one vector at a time. do { - Vector128.StoreUnsafe( - TBinaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, (uint)i), - yVec), - ref dRef, (uint)i); + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + yVec).StoreUnsafe(ref dRef, (uint)i); i += Vector128.Count; } @@ -522,11 +312,8 @@ private static unsafe void InvokeSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.StoreUnsafe( - TBinaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec), - ref dRef, lastVectorIndex); + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + yVec).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -535,7 +322,8 @@ private static unsafe void InvokeSpanScalarIntoSpan( while (i < x.Length) { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), y); + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y); i++; } @@ -570,12 +358,9 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( // Loop handling one vector at a time. do { - Vector512.StoreUnsafe( - TTernaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - Vector512.LoadUnsafe(ref zRef, (uint)i)), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i), + Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector512.Count; } @@ -585,12 +370,9 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.StoreUnsafe( - TTernaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - Vector512.LoadUnsafe(ref zRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex), + Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -606,12 +388,9 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( // Loop handling one vector at a time. do { - Vector256.StoreUnsafe( - TTernaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - Vector256.LoadUnsafe(ref zRef, (uint)i)), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i), + Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector256.Count; } @@ -621,12 +400,9 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.StoreUnsafe( - TTernaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - Vector256.LoadUnsafe(ref zRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex), + Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -641,12 +417,9 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( // Loop handling one vector at a time. do { - Vector128.StoreUnsafe( - TTernaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - Vector128.LoadUnsafe(ref zRef, (uint)i)), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i), + Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector128.Count; } @@ -656,12 +429,9 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.StoreUnsafe( - TTernaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - Vector128.LoadUnsafe(ref zRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex), + Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -670,10 +440,9 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( while (i < x.Length) { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke( - Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); i++; } @@ -709,12 +478,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( // Loop handling one vector at a time. do { - Vector512.StoreUnsafe( - TTernaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - zVec), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i), + zVec).StoreUnsafe(ref dRef, (uint)i); i += Vector512.Count; } @@ -724,12 +490,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.StoreUnsafe( - TTernaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - zVec), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex), + zVec).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -747,12 +510,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( // Loop handling one vector at a time. do { - Vector256.StoreUnsafe( - TTernaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - zVec), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i), + zVec).StoreUnsafe(ref dRef, (uint)i); i += Vector256.Count; } @@ -762,12 +522,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.StoreUnsafe( - TTernaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - zVec), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex), + zVec).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -784,12 +541,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( // Loop handling one vector at a time. do { - Vector128.StoreUnsafe( - TTernaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - zVec), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i), + zVec).StoreUnsafe(ref dRef, (uint)i); i += Vector128.Count; } @@ -799,12 +553,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.StoreUnsafe( - TTernaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - zVec), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex), + zVec).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -813,10 +564,9 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( while (i < x.Length) { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke( - Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); i++; } @@ -852,12 +602,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( // Loop handling one vector at a time. do { - Vector512.StoreUnsafe( - TTernaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector512.LoadUnsafe(ref zRef, (uint)i)), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector512.Count; } @@ -867,12 +614,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.StoreUnsafe( - TTernaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector512.LoadUnsafe(ref zRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -890,12 +634,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( // Loop handling one vector at a time. do { - Vector256.StoreUnsafe( - TTernaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector256.LoadUnsafe(ref zRef, (uint)i)), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector256.Count; } @@ -905,12 +646,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.StoreUnsafe( - TTernaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector256.LoadUnsafe(ref zRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -927,12 +665,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( // Loop handling one vector at a time. do { - Vector128.StoreUnsafe( - TTernaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector128.LoadUnsafe(ref zRef, (uint)i)), - ref dRef, (uint)i); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); i += Vector128.Count; } @@ -942,12 +677,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.StoreUnsafe( - TTernaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector128.LoadUnsafe(ref zRef, lastVectorIndex)), - ref dRef, lastVectorIndex); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -956,10 +688,9 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( while (i < x.Length) { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke( - Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); i++; } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index b10984be641bd..ddac0f47a685c 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -6,169 +6,10 @@ namespace System.Numerics.Tensors { - /// Performs primitive tensor operations over spans of memory. - public static unsafe partial class TensorPrimitives + public static partial class TensorPrimitives { - /// Computes the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] + [i]. - public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination, default(AddOperator)); - - /// Computes the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] + . - public static void Add(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination, default(AddOperator)); - - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] - [i]. - public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination, default(SubtractOperator)); - - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] - . - public static void Subtract(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination, default(SubtractOperator)); - - /// Computes the element-wise result of: * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] * . - public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination, default(MultiplyOperator)); - - /// Computes the element-wise result of: * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// - /// This method effectively does [i] = [i] * . - /// This method corresponds to the scal method defined by BLAS1. - /// - public static void Multiply(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination, default(MultiplyOperator)); - - /// Computes the element-wise result of: / . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination, default(DivideOperator)); - - /// Computes the element-wise result of: / . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination, default(DivideOperator)); - - /// Computes the element-wise result of: -. - /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = -[i]. - public static void Negate(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination, default(NegateOperator)); - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * [i]. - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination, default(AddMultiplyOperator)); - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * . - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination, default(AddMultiplyOperator)); - - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + ) * [i]. - public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination, default(AddMultiplyOperator)); - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] * [i]) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, addend, destination, default(MultiplyAddOperator)); - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// - /// This method effectively does [i] = ([i] * [i]) + . - /// This method corresponds to the axpy method defined by BLAS1. - /// - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, addend, destination, default(MultiplyAddOperator)); - - /// Computes the element-wise result of: ( * ) + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] * ) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, addend, destination, default(MultiplyAddOperator)); - private static void InvokeSpanIntoSpan( - ReadOnlySpan x, Span destination, TUnaryOperator op) + ReadOnlySpan x, Span destination, TUnaryOperator op = default) where TUnaryOperator : IUnaryOperator { if (x.Length > destination.Length) @@ -188,9 +29,7 @@ private static void InvokeSpanIntoSpan( // Loop handling one vector at a time. do { - Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, i))); + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i)); i += Vector.Count; } @@ -200,9 +39,7 @@ private static void InvokeSpanIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex))); + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex)); } return; @@ -219,7 +56,7 @@ private static void InvokeSpanIntoSpan( } private static void InvokeSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op) + ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) where TBinaryOperator : IBinaryOperator { if (x.Length != y.Length) @@ -245,10 +82,8 @@ private static void InvokeSpanSpanIntoSpan( // Loop handling one vector at a time. do { - Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, i)), - Unsafe.As>(ref Unsafe.Add(ref yRef, i))); + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + AsVector(ref yRef, i)); i += Vector.Count; } @@ -258,10 +93,8 @@ private static void InvokeSpanSpanIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), - Unsafe.As>(ref Unsafe.Add(ref yRef, lastVectorIndex))); + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex)); } return; @@ -270,17 +103,15 @@ private static void InvokeSpanSpanIntoSpan( while (i < x.Length) { - Unsafe.Add(ref dRef, i) = - op.Invoke( - Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); i++; } } private static void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination, TBinaryOperator op) + ReadOnlySpan x, float y, Span destination, TBinaryOperator op = default) where TBinaryOperator : IBinaryOperator { if (x.Length > destination.Length) @@ -298,13 +129,11 @@ private static void InvokeSpanScalarIntoSpan( if (oneVectorFromEnd >= 0) { // Loop handling one vector at a time. - Vector yVec = new Vector(y); + Vector yVec = new(y); do { - Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, i)), - yVec); + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + yVec); i += Vector.Count; } @@ -314,10 +143,8 @@ private static void InvokeSpanScalarIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), - yVec); + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + yVec); } return; @@ -327,17 +154,15 @@ private static void InvokeSpanScalarIntoSpan( // Loop handling one element at a time. while (i < x.Length) { - Unsafe.Add(ref dRef, i) = - op.Invoke( - Unsafe.Add(ref xRef, i), - y); + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + y); i++; } } private static void InvokeSpanSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op) + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) where TTernaryOperator : ITernaryOperator { if (x.Length != y.Length || x.Length != z.Length) @@ -364,11 +189,9 @@ private static void InvokeSpanSpanSpanIntoSpan( // Loop handling one vector at a time. do { - Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, i)), - Unsafe.As>(ref Unsafe.Add(ref yRef, i)), - Unsafe.As>(ref Unsafe.Add(ref zRef, i))); + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + AsVector(ref yRef, i), + AsVector(ref zRef, i)); i += Vector.Count; } @@ -378,11 +201,9 @@ private static void InvokeSpanSpanSpanIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), - Unsafe.As>(ref Unsafe.Add(ref yRef, lastVectorIndex)), - Unsafe.As>(ref Unsafe.Add(ref zRef, lastVectorIndex))); + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex), + AsVector(ref zRef, lastVectorIndex)); } return; @@ -392,17 +213,16 @@ private static void InvokeSpanSpanSpanIntoSpan( // Loop handling one element at a time. while (i < x.Length) { - Unsafe.Add(ref dRef, i) = op.Invoke( - Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); i++; } } private static void InvokeSpanSpanScalarIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op) + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) where TTernaryOperator : ITernaryOperator { if (x.Length != y.Length) @@ -425,16 +245,14 @@ private static void InvokeSpanSpanScalarIntoSpan( oneVectorFromEnd = x.Length - Vector.Count; if (oneVectorFromEnd >= 0) { - Vector zVec = new Vector(z); + Vector zVec = new(z); // Loop handling one vector at a time. do { - Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, i)), - Unsafe.As>(ref Unsafe.Add(ref yRef, i)), - zVec); + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + AsVector(ref yRef, i), + zVec); i += Vector.Count; } @@ -444,11 +262,9 @@ private static void InvokeSpanSpanScalarIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), - Unsafe.As>(ref Unsafe.Add(ref yRef, lastVectorIndex)), - zVec); + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex), + zVec); } return; @@ -458,17 +274,16 @@ private static void InvokeSpanSpanScalarIntoSpan( // Loop handling one element at a time. while (i < x.Length) { - Unsafe.Add(ref dRef, i) = op.Invoke( - Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); i++; } } private static void InvokeSpanScalarSpanIntoSpan( - ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op) + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) where TTernaryOperator : ITernaryOperator { if (x.Length != z.Length) @@ -491,16 +306,14 @@ private static void InvokeSpanScalarSpanIntoSpan( oneVectorFromEnd = x.Length - Vector.Count; if (oneVectorFromEnd >= 0) { - Vector yVec = new Vector(y); + Vector yVec = new(y); // Loop handling one vector at a time. do { - Unsafe.As>(ref Unsafe.Add(ref dRef, i)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, i)), - yVec, - Unsafe.As>(ref Unsafe.Add(ref zRef, i))); + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + yVec, + AsVector(ref zRef, i)); i += Vector.Count; } @@ -510,11 +323,9 @@ private static void InvokeSpanScalarSpanIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - Unsafe.As>(ref Unsafe.Add(ref dRef, lastVectorIndex)) = - op.Invoke( - Unsafe.As>(ref Unsafe.Add(ref xRef, lastVectorIndex)), - yVec, - Unsafe.As>(ref Unsafe.Add(ref zRef, lastVectorIndex))); + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + yVec, + AsVector(ref zRef, lastVectorIndex)); } return; @@ -524,15 +335,19 @@ private static void InvokeSpanScalarSpanIntoSpan( // Loop handling one element at a time. while (i < x.Length) { - Unsafe.Add(ref dRef, i) = op.Invoke( - Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); i++; } } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, int offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, offset)); + private readonly struct AddOperator : IBinaryOperator { public float Invoke(float x, float y) => x + y; diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs index ebe79895196ac..cc8d423f5a4d0 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs @@ -5,28 +5,14 @@ namespace System { - // - // This pattern of easily inlinable "void Throw" routines that stack on top of NoInlining factory methods - // is a compromise between older JITs and newer JITs (RyuJIT in .NET Core 1.1.0+ and .NET Framework in 4.6.3+). - // This package is explicitly targeted at older JITs as newer runtimes expect to implement Span intrinsically for - // best performance. - // - // The aim of this pattern is three-fold - // 1. Extracting the throw makes the method preforming the throw in a conditional branch smaller and more inlinable - // 2. Extracting the throw from generic method to non-generic method reduces the repeated codegen size for value types - // 3a. Newer JITs will not inline the methods that only throw and also recognise them, move the call to cold section - // and not add stack prep and unwind before calling https://github.com/dotnet/coreclr/pull/6103 - // 3b. Older JITs will inline the throw itself and move to cold section; but not inline the non-inlinable exception - // factory methods - still maintaining advantages 1 & 2 - // - internal static class ThrowHelper { [DoesNotReturn] - public static void ThrowArgument_DestinationTooShort() => throw new ArgumentException(SR.Argument_DestinationTooShort); + public static void ThrowArgument_DestinationTooShort() => + throw new ArgumentException(SR.Argument_DestinationTooShort, "destination"); [DoesNotReturn] - public static void ThrowArgument_SpansMustHaveSameLength() - => throw new ArgumentException(SR.Argument_SpansMustHaveSameLength); + public static void ThrowArgument_SpansMustHaveSameLength() => + throw new ArgumentException(SR.Argument_SpansMustHaveSameLength); } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index a442df3d26af7..5a9912542a8c2 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -82,7 +82,7 @@ public static void AddTwoTensors_ThrowsForTooShortDestination(int tensorLength) float[] y = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); } [Theory] @@ -109,7 +109,7 @@ public static void AddTensorAndScalar_ThrowsForTooShortDestination(int tensorLen float y = NextSingle(); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); } [Theory] @@ -147,7 +147,7 @@ public static void SubtractTwoTensors_ThrowsForTooShortDestination(int tensorLen float[] y = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); } [Theory] @@ -174,7 +174,7 @@ public static void SubtractTensorAndScalar_ThrowsForTooShortDestination(int tens float y = NextSingle(); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); } [Theory] @@ -212,7 +212,7 @@ public static void MultiplyTwoTensors_ThrowsForTooShortDestination(int tensorLen float[] y = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); } [Theory] @@ -239,7 +239,7 @@ public static void MultiplyTensorAndScalar_ThrowsForTooShortDestination(int tens float y = NextSingle(); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); } [Theory] @@ -277,7 +277,7 @@ public static void DivideTwoTensors_ThrowsForTooShortDestination(int tensorLengt float[] y = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); } [Theory] @@ -304,7 +304,7 @@ public static void DivideTensorAndScalar_ThrowsForTooShortDestination(int tensor float y = NextSingle(); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); } [Theory] @@ -329,7 +329,7 @@ public static void NegateTensor_ThrowsForTooShortDestination(int tensorLength) float[] x = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Negate(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); } [Theory] @@ -382,7 +382,7 @@ public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDest float[] multiplier = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } [Theory] @@ -402,6 +402,18 @@ public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) + { + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float multiplier = NextSingle(); + float[] destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + } + [Theory] [MemberData(nameof(TensorLengths))] public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestination(int tensorLength) @@ -411,7 +423,7 @@ public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestinati float multiplier = NextSingle(); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } [Theory] @@ -431,6 +443,18 @@ public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) + { + float[] x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + float[] multiplier = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + } + [Theory] [MemberData(nameof(TensorLengths))] public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDestination(int tensorLength) @@ -440,7 +464,7 @@ public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDest float[] multiplier = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } [Theory] @@ -493,7 +517,7 @@ public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForTooShortDest float[] addend = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } [Theory] @@ -522,7 +546,7 @@ public static void MultiplyTwoTensorsAndAddWithScalar_ThrowsForTooShortDestinati float addend = NextSingle(); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } [Theory] @@ -551,7 +575,7 @@ public static void MultiplyTensorAndScalarAndAddWithTensor_ThrowsForTooShortDest float[] addend = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } [Theory] @@ -576,7 +600,7 @@ public static void ExpTensor_ThrowsForTooShortDestination(int tensorLength) float[] x = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Exp(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(x, destination)); } [Theory] @@ -601,7 +625,7 @@ public static void LogTensor_ThrowsForTooShortDestination(int tensorLength) float[] x = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Log(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(x, destination)); } [Theory] @@ -626,7 +650,7 @@ public static void CoshTensor_ThrowsForTooShortDestination(int tensorLength) float[] x = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Cosh(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); } [Theory] @@ -651,7 +675,7 @@ public static void SinhTensor_ThrowsForTooShortDestination(int tensorLength) float[] x = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Sinh(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); } [Theory] @@ -676,7 +700,7 @@ public static void TanhTensor_ThrowsForTooShortDestination(int tensorLength) float[] x = CreateAndFillTensor(tensorLength); float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Tanh(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(x, destination)); } } }