diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx index f9f7f45078a04..d70054a241126 100644 --- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx +++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx @@ -121,6 +121,6 @@ Destination is too short. - Length of '{0}' must be same as length of '{1}'. + Input span arguments must all have the same length. - \ No newline at end of file + diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index 5c188c2e0b841..097fa244ad491 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -16,6 +16,14 @@ + + + + + + + + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 10f29183ea286..08bd9d362217e 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -4,7 +4,7 @@ namespace System.Numerics.Tensors { /// Performs primitive tensor operations over spans of memory. - public static class TensorPrimitives + public static partial class TensorPrimitives { /// Computes the element-wise result of: + . /// The first tensor, represented as a span. @@ -13,23 +13,8 @@ public static class TensorPrimitives /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = [i] + [i]. - public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] + y[i]; - } - } + public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise result of: + . /// The first tensor, represented as a span. @@ -37,18 +22,8 @@ public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span /// The destination tensor, represented as a span. /// Destination is too short. /// This method effectively does [i] = [i] + . - public static void Add(ReadOnlySpan x, float y, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] + y; - } - } + public static void Add(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); /// Computes the element-wise result of: - . /// The first tensor, represented as a span. @@ -57,23 +32,8 @@ public static void Add(ReadOnlySpan x, float y, Span destination) /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = [i] - [i]. - public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] - y[i]; - } - } + public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise result of: - . /// The first tensor, represented as a span. @@ -81,18 +41,8 @@ public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, SpanThe destination tensor, represented as a span. /// Destination is too short. /// This method effectively does [i] = [i] - . - public static void Subtract(ReadOnlySpan x, float y, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] - y; - } - } + public static void Subtract(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); /// Computes the element-wise result of: * . /// The first tensor, represented as a span. @@ -101,23 +51,8 @@ public static void Subtract(ReadOnlySpan x, float y, Span destinat /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = [i] * . - public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] * y[i]; - } - } + public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise result of: * . /// The first tensor, represented as a span. @@ -128,18 +63,8 @@ public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, SpanThis method effectively does [i] = [i] * . /// This method corresponds to the scal method defined by BLAS1. /// - public static void Multiply(ReadOnlySpan x, float y, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] * y; - } - } + public static void Multiply(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); /// Computes the element-wise result of: / . /// The first tensor, represented as a span. @@ -148,23 +73,8 @@ public static void Multiply(ReadOnlySpan x, float y, Span destinat /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] / y[i]; - } - } + public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); /// Computes the element-wise result of: / . /// The first tensor, represented as a span. @@ -172,36 +82,16 @@ public static void Divide(ReadOnlySpan x, ReadOnlySpan y, SpanThe destination tensor, represented as a span. /// Destination is too short. /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, float y, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = x[i] / y; - } - } + public static void Divide(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); /// Computes the element-wise result of: -. /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. /// This method effectively does [i] = -[i]. - public static void Negate(ReadOnlySpan x, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = -x[i]; - } - } + public static void Negate(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); /// Computes the element-wise result of: ( + ) * . /// The first tensor, represented as a span. @@ -212,28 +102,8 @@ public static void Negate(ReadOnlySpan x, Span destination) /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = ([i] + [i]) * [i]. - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length != multiplier.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(multiplier)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] + y[i]) * multiplier[i]; - } - } + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); /// Computes the element-wise result of: ( + ) * . /// The first tensor, represented as a span. @@ -243,23 +113,8 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, Rea /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = ([i] + [i]) * . - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] + y[i]) * multiplier; - } - } + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); /// Computes the element-wise result of: ( + ) * . /// The first tensor, represented as a span. @@ -269,23 +124,8 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, flo /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = ([i] + ) * [i]. - public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) - { - if (x.Length != multiplier.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(multiplier)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] + y) * multiplier[i]; - } - } + public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); /// Computes the element-wise result of: ( * ) + . /// The first tensor, represented as a span. @@ -296,28 +136,8 @@ public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpanLength of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = ([i] * [i]) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length != addend.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(addend)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] * y[i]) + addend[i]; - } - } + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); /// Computes the element-wise result of: ( * ) + . /// The first tensor, represented as a span. @@ -330,23 +150,8 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, Rea /// This method effectively does [i] = ([i] * [i]) + . /// This method corresponds to the axpy method defined by BLAS1. /// - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(y)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] * y[i]) + addend; - } - } + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); /// Computes the element-wise result of: ( * ) + . /// The first tensor, represented as a span. @@ -356,23 +161,8 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, flo /// Length of '' must be same as length of ''. /// Destination is too short. /// This method effectively does [i] = ([i] * ) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) - { - if (x.Length != addend.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(nameof(x), nameof(addend)); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = (x[i] * y) + addend[i]; - } - } + public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); /// Computes the element-wise result of: pow(e, ). /// The tensor, represented as a span. diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs new file mode 100644 index 0000000000000..1233f54901c80 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -0,0 +1,793 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + private static unsafe void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination) + where TUnaryOperator : IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); + + i++; + } + } + + private static unsafe void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination) + where TBinaryOperator : IBinaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); + + i++; + } + } + + private static unsafe void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TBinaryOperator : IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + Vector512 yVec = Vector512.Create(y); + + // Loop handling one vector at a time. + do + { + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + yVec).StoreUnsafe(ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + yVec).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + Vector256 yVec = Vector256.Create(y); + + // Loop handling one vector at a time. + do + { + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + yVec).StoreUnsafe(ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + yVec).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + Vector128 yVec = Vector128.Create(y); + + // Loop handling one vector at a time. + do + { + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + yVec).StoreUnsafe(ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + yVec).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y); + + i++; + } + } + + private static unsafe void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i), + Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex), + Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i), + Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex), + Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i), + Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex), + Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); + + i++; + } + } + + private static unsafe void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + Vector512 zVec = Vector512.Create(z); + + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + Vector512.LoadUnsafe(ref yRef, (uint)i), + zVec).StoreUnsafe(ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex), + zVec).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + Vector256 zVec = Vector256.Create(z); + + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + Vector256.LoadUnsafe(ref yRef, (uint)i), + zVec).StoreUnsafe(ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex), + zVec).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + Vector128 zVec = Vector128.Create(z); + + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + Vector128.LoadUnsafe(ref yRef, (uint)i), + zVec).StoreUnsafe(ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex), + zVec).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + + i++; + } + } + + private static unsafe void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + Vector512 yVec = Vector512.Create(y); + + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector512.Count); + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + Vector256 yVec = Vector256.Create(y); + + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector256.Count); + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + Vector128 yVec = Vector128.Create(y); + + // Loop handling one vector at a time. + do + { + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), + yVec, + Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + uint lastVectorIndex = (uint)(x.Length - Vector128.Count); + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); + + i++; + } + } + + private readonly struct AddOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x + y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; +#endif + } + + private readonly struct SubtractOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x - y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; +#endif + } + + private readonly struct MultiplyOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x * y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; +#endif + } + + private readonly struct DivideOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x / y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; +#endif + } + + private readonly struct NegateOperator : IUnaryOperator + { + public static float Invoke(float x) => -x; + public static Vector128 Invoke(Vector128 x) => -x; + public static Vector256 Invoke(Vector256 x) => -x; + public static Vector512 Invoke(Vector512 x) => -x; + } + + private readonly struct AddMultiplyOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x + y) * z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; + } + + private readonly struct MultiplyAddOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x * y) + z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; + } + + private interface IUnaryOperator + { + static abstract float Invoke(float x); + static abstract Vector128 Invoke(Vector128 x); + static abstract Vector256 Invoke(Vector256 x); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x); +#endif + } + + private interface IBinaryOperator + { + static abstract float Invoke(float x, float y); + static abstract Vector128 Invoke(Vector128 x, Vector128 y); + static abstract Vector256 Invoke(Vector256 x, Vector256 y); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x, Vector512 y); +#endif + } + + private interface ITernaryOperator + { + static abstract float Invoke(float x, float y, float z); + static abstract Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z); + static abstract Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z); +#endif + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs new file mode 100644 index 0000000000000..ddac0f47a685c --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -0,0 +1,411 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + private static void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination, TUnaryOperator op = default) + where TUnaryOperator : IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) + { + // Loop handling one vector at a time. + do + { + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i)); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex)); + } + + return; + } + } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i)); + + i++; + } + } + + private static void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) + where TBinaryOperator : IBinaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) + { + // Loop handling one vector at a time. + do + { + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + AsVector(ref yRef, i)); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex)); + } + + return; + } + } + + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); + + i++; + } + } + + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination, TBinaryOperator op = default) + where TBinaryOperator : IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) + { + // Loop handling one vector at a time. + Vector yVec = new(y); + do + { + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + yVec); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + yVec); + } + + return; + } + } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + y); + + i++; + } + } + + private static void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) + { + // Loop handling one vector at a time. + do + { + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + AsVector(ref yRef, i), + AsVector(ref zRef, i)); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex), + AsVector(ref zRef, lastVectorIndex)); + } + + return; + } + } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); + + i++; + } + } + + private static void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) + { + Vector zVec = new(z); + + // Loop handling one vector at a time. + do + { + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + AsVector(ref yRef, i), + zVec); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex), + zVec); + } + + return; + } + } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + + i++; + } + } + + private static void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + + if (Vector.IsHardwareAccelerated) + { + oneVectorFromEnd = x.Length - Vector.Count; + if (oneVectorFromEnd >= 0) + { + Vector yVec = new(y); + + // Loop handling one vector at a time. + do + { + AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), + yVec, + AsVector(ref zRef, i)); + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final vector. + if (i != x.Length) + { + int lastVectorIndex = x.Length - Vector.Count; + AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), + yVec, + AsVector(ref zRef, lastVectorIndex)); + } + + return; + } + } + + // Loop handling one element at a time. + while (i < x.Length) + { + Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); + + i++; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, int offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, offset)); + + private readonly struct AddOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x + y; + public Vector Invoke(Vector x, Vector y) => x + y; + } + + private readonly struct SubtractOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x - y; + public Vector Invoke(Vector x, Vector y) => x - y; + } + + private readonly struct MultiplyOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x * y; + public Vector Invoke(Vector x, Vector y) => x * y; + } + + private readonly struct DivideOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x / y; + public Vector Invoke(Vector x, Vector y) => x / y; + } + + private readonly struct NegateOperator : IUnaryOperator + { + public float Invoke(float x) => -x; + public Vector Invoke(Vector x) => -x; + } + + private readonly struct AddMultiplyOperator : ITernaryOperator + { + public float Invoke(float x, float y, float z) => (x + y) * z; + public Vector Invoke(Vector x, Vector y, Vector z) => (x + y) * z; + } + + private readonly struct MultiplyAddOperator : ITernaryOperator + { + public float Invoke(float x, float y, float z) => (x * y) + z; + public Vector Invoke(Vector x, Vector y, Vector z) => (x * y) + z; + } + + private interface IUnaryOperator + { + float Invoke(float x); + Vector Invoke(Vector x); + } + + private interface IBinaryOperator + { + float Invoke(float x, float y); + Vector Invoke(Vector x, Vector y); + } + + private interface ITernaryOperator + { + float Invoke(float x, float y, float z); + Vector Invoke(Vector x, Vector y, Vector z); + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs index 5ae36ffbed768..cc8d423f5a4d0 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs @@ -5,28 +5,14 @@ namespace System { - // - // This pattern of easily inlinable "void Throw" routines that stack on top of NoInlining factory methods - // is a compromise between older JITs and newer JITs (RyuJIT in .NET Core 1.1.0+ and .NET Framework in 4.6.3+). - // This package is explicitly targeted at older JITs as newer runtimes expect to implement Span intrinsically for - // best performance. - // - // The aim of this pattern is three-fold - // 1. Extracting the throw makes the method preforming the throw in a conditional branch smaller and more inlinable - // 2. Extracting the throw from generic method to non-generic method reduces the repeated codegen size for value types - // 3a. Newer JITs will not inline the methods that only throw and also recognise them, move the call to cold section - // and not add stack prep and unwind before calling https://github.com/dotnet/coreclr/pull/6103 - // 3b. Older JITs will inline the throw itself and move to cold section; but not inline the non-inlinable exception - // factory methods - still maintaining advantages 1 & 2 - // - internal static class ThrowHelper { [DoesNotReturn] - public static void ThrowArgument_DestinationTooShort() => throw new ArgumentException(SR.Argument_DestinationTooShort); + public static void ThrowArgument_DestinationTooShort() => + throw new ArgumentException(SR.Argument_DestinationTooShort, "destination"); [DoesNotReturn] - public static void ThrowArgument_SpansMustHaveSameLength(string paramName1, string paramName2) - => throw new ArgumentException(SR.Format(SR.Argument_SpansMustHaveSameLength, paramName1, paramName2), paramName1); + public static void ThrowArgument_SpansMustHaveSameLength() => + throw new ArgumentException(SR.Argument_SpansMustHaveSameLength); } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index 1e751543831c3..5a9912542a8c2 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -3,6 +3,7 @@ using Xunit; using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; #pragma warning disable xUnit1025 // reporting duplicate test cases due to not distinguishing 0.0 from -0.0 @@ -11,9 +12,9 @@ namespace System.Numerics.Tensors.Tests { public static class TensorPrimitivesTests { - private const int TensorSize = 512; - - private const int MismatchedTensorSize = 2; + public static IEnumerable TensorLengths => + from length in new[] { 1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17, 31, 32, 33, 100 } + select new object[] { length }; private static readonly Random s_random = new Random(20230828); @@ -46,588 +47,660 @@ private static float NextSingle() #endif } - [Fact] - public static void AddTwoTensors() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensors(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Add(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y[i]), destination[i]); } } - [Fact] - public static void AddTwoTensors_ThrowsForMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); } - [Fact] - public static void AddTwoTensors_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensors_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); } - [Fact] - public static void AddTensorAndScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Add(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y), destination[i]); } } - [Fact] - public static void AddTensorAndScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); } - [Fact] - public static void SubtractTwoTensors() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTwoTensors(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Subtract(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] - y[i]), destination[i]); } } - [Fact] - public static void SubtractTwoTensors_ThrowsForMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); } - [Fact] - public static void SubtractTwoTensors_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTwoTensors_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); } - [Fact] - public static void SubtractTensorAndScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTensorAndScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Subtract(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] - y), destination[i]); } } - [Fact] - public static void SubtractTensorAndScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SubtractTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); } - [Fact] - public static void MultiplyTwoTensors() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensors(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Multiply(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y[i]), destination[i]); } } - [Fact] - public static void MultiplyTwoTensors_ThrowsForMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); } - [Fact] - public static void MultiplyTwoTensors_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensors_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); } - [Fact] - public static void MultiplyTensorAndScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTensorAndScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Multiply(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y), destination[i]); } } - [Fact] - public static void MultiplyTensorAndScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); } - [Fact] - public static void DivideTwoTensors() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTwoTensors(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Divide(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] / y[i]), destination[i]); } } - [Fact] - public static void DivideTwoTensors_ThrowsForMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); } - [Fact] - public static void DivideTwoTensors_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTwoTensors_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); } - [Fact] - public static void DivideTensorAndScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTensorAndScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Divide(x, y, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] / y), destination[i]); } } - [Fact] - public static void DivideTensorAndScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void DivideTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); } - [Fact] - public static void NegateTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void NegateTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Negate(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(-x[i], destination[i]); } } - [Fact] - public static void NegateTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void NegateTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Negate(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); } - [Fact] - public static void AddTwoTensorsAndMultiplyWithThirdTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i]); } } - [Fact] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_y() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_y(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] multiplier = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] multiplier = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void AddTwoTensorsAndMultiplyWithScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); float multiplier = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y[i]) * multiplier, destination[i]); } } - [Fact] - public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); float multiplier = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void AddTensorAndScalarAndMultiplyWithTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float multiplier = NextSingle(); + float[] destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength) + { + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.AddMultiply(x, y, multiplier, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] + y) * multiplier[i], destination[i]); } } - [Fact] - public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] multiplier = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] multiplier = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } - [Fact] - public static void MultiplyTwoTensorsAndAddWithThirdTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + float[] multiplier = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithThirdTensor(int tensorLength) + { + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y[i]) + addend[i], destination[i]); } } - [Fact] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_y() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_y(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(MismatchedTensorSize); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength - 1); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] addend = CreateAndFillTensor(MismatchedTensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] addend = CreateAndFillTensor(tensorLength - 1); + float[] destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void MultiplyTwoTensorsAndAddWithScalar() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithScalar(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); float addend = NextSingle(); - float[] destination = CreateTensor(TensorSize); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y[i]) + addend, destination[i]); } } - [Fact] - public static void MultiplyTwoTensorsAndAddWithScalar_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTwoTensorsAndAddWithScalar_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] y = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] y = CreateAndFillTensor(tensorLength); float addend = NextSingle(); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void MultiplyTensorAndScalarAndAddWithTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTensorAndScalarAndAddWithTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.MultiplyAdd(x, y, addend, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal((x[i] * y) + addend[i], destination[i]); } } - [Fact] - public static void MultiplyTensorAndScalarAndAddWithTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyTensorAndScalarAndAddWithTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - float[] addend = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] addend = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); } - [Fact] - public static void ExpTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ExpTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Exp(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Exp(x[i]), destination[i]); } } - [Fact] - public static void ExpTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ExpTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Exp(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(x, destination)); } - [Fact] - public static void LogTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void LogTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Log(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Log(x[i]), destination[i]); } } - [Fact] - public static void LogTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void LogTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Log(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(x, destination)); } - [Fact] - public static void CoshTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void CoshTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Cosh(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Cosh(x[i]), destination[i]); } } - [Fact] - public static void CoshTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void CoshTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Cosh(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); } - [Fact] - public static void SinhTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SinhTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Sinh(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Sinh(x[i]), destination[i]); } } - [Fact] - public static void SinhTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SinhTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Sinh(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); } - [Fact] - public static void TanhTensor() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void TanhTensor(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(TensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength); TensorPrimitives.Tanh(x, destination); - for (int i = 0; i < TensorSize; i++) + for (int i = 0; i < tensorLength; i++) { Assert.Equal(MathF.Tanh(x[i]), destination[i]); } } - [Fact] - public static void TanhTensor_ThrowsForTooShortDestination() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void TanhTensor_ThrowsForTooShortDestination(int tensorLength) { - float[] x = CreateAndFillTensor(TensorSize); - float[] destination = CreateTensor(MismatchedTensorSize); + float[] x = CreateAndFillTensor(tensorLength); + float[] destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Tanh(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(x, destination)); } } }