Skip to content

Commit b8c55f8

Browse files
authored
Cache Exp in SoftMax calculation (#111615)
* fix softmax * Sum(destination)
1 parent 4974c12 commit b8c55f8

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.Single.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -859,9 +859,9 @@ public static void SoftMax(ReadOnlySpan<float> x, Span<float> destination)
859859

860860
ValidateInputOutputSpanNonOverlapping(x, destination);
861861

862-
float expSum = Aggregate<ExpOperator_Single, AddOperator_Single>(x);
863-
864-
InvokeSpanScalarIntoSpan<ExpOperator_Single, DivideOperator_Single>(x, expSum, destination);
862+
InvokeSpanIntoSpan<ExpOperator_Single>(x, destination);
863+
float expSum = Sum(destination);
864+
InvokeSpanScalarIntoSpan<DivideOperator_Single>(destination, expSum, destination);
865865
}
866866

867867
/// <summary>Computes the element-wise difference between single-precision floating-point numbers in the specified tensors.</summary>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.SoftMax.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ public static void SoftMax<T>(ReadOnlySpan<T> x, Span<T> destination)
3636

3737
ValidateInputOutputSpanNonOverlapping(x, destination);
3838

39-
T expSum = Aggregate<T, ExpOperator<T>, AddOperator<T>>(x);
40-
41-
InvokeSpanScalarIntoSpan<T, ExpOperator<T>, DivideOperator<T>>(x, expSum, destination);
39+
InvokeSpanIntoSpan<T, ExpOperator<T>>(x, destination);
40+
T expSum = Sum(destination);
41+
InvokeSpanScalarIntoSpan<T, DivideOperator<T>>(destination, expSum, destination);
4242
}
4343
}
4444
}

0 commit comments

Comments
 (0)