Skip to content

Reduce matmul latency by splitting small matmul #54421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 12, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 43 additions & 74 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -934,15 +934,23 @@ function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
function __matmul_checks(C, A, B, sz)
require_one_based_indexing(C, A, B)
if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end
if !(size(A) == size(B) == size(C) == (2,2))
if !(size(A) == size(B) == size(C) == sz)
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
end
return nothing
end

# separate function with the core of matmul2x2! that doesn't depend on a MulAddMul
function _matmul2x2_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix)
__matmul_checks(C, A, B, (2,2))
__matmul2x2_elements(tA, tB, A, B)
end
function __matmul2x2_elements(tA, A::AbstractMatrix)
@inbounds begin
if tA == 'N'
A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2]
Expand All @@ -967,33 +975,18 @@ function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L)
end
if tB == 'N'
B11 = B[1,1]; B12 = B[1,2];
B21 = B[2,1]; B22 = B[2,2]
elseif tB == 'T'
# TODO making these lazy could improve perf
B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1]))
B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2]))
elseif tB == 'C'
# TODO making these lazy could improve perf
B11 = copy(B[1,1]'); B12 = copy(B[2,1]')
B21 = copy(B[1,2]'); B22 = copy(B[2,2]')
elseif tB == 'S'
B11 = symmetric(B[1,1], :U); B12 = B[1,2]
B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U)
elseif tB == 's'
B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1]))
B21 = B[2,1]; B22 = symmetric(B[2,2], :L)
elseif tB == 'H'
B11 = hermitian(B[1,1], :U); B12 = B[1,2]
B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U)
else # if tB == 'h'
B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1]))
B21 = B[2,1]; B22 = hermitian(B[2,2], :L)
end
end # inbounds
A11, A12, A21, A22
end
__matmul2x2_elements(tA, tB, A, B) = __matmul2x2_elements(tA, A), __matmul2x2_elements(tB, B)

function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if you added aggressive constant propagation here? tA and tB should be known as constants at the call site, because they are obtained from unpeeling outer wrappers. Does that also increase latency again?

Copy link
Member Author

@jishnub jishnub May 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had tried this, but this doesn't seem to change the compilation time at all. Some of the branches are certainly eliminated, but the bulk of the latency probably arises elsewhere.

This doesn't make things worse either, though, so perhaps we may include this, just in case the other issues are resolved in the future. This may also let us use lazy versions of the wrappers instead of copies, as this would be type-stable if the other branches are eliminated.

_add::MulAddMul = MulAddMul())
(A11, A12, A21, A22), (B11, B12, B21, B22) = _matmul2x2_elements(C, tA, tB, A, B)
@inbounds begin
_modify!(_add, A11*B11 + A12*B21, C, (1,1))
_modify!(_add, A11*B12 + A12*B22, C, (1,2))
_modify!(_add, A21*B11 + A22*B21, C, (2,1))
_modify!(_add, A11*B12 + A12*B22, C, (1,2))
_modify!(_add, A21*B12 + A22*B22, C, (2,2))
end # inbounds
C
Expand All @@ -1004,15 +997,12 @@ function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end
if !(size(A) == size(B) == size(C) == (3,3))
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
end
# separate function with the core of matmul3x3! that doesn't depend on a MulAddMul
function _matmul3x3_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix)
__matmul_checks(C, A, B, (3,3))
__matmul3x3_elements(tA, tB, A, B)
end
function __matmul3x3_elements(tA, A::AbstractMatrix)
@inbounds begin
if tA == 'N'
A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3]
Expand Down Expand Up @@ -1045,49 +1035,28 @@ function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMat
A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L)
end
end # inbounds
A11, A12, A13, A21, A22, A23, A31, A32, A33
end
__matmul3x3_elements(tA, tB, A, B) = __matmul3x3_elements(tA, A), __matmul3x3_elements(tB, B)

if tB == 'N'
B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3]
B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3]
B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3]
elseif tB == 'T'
# TODO making these lazy could improve perf
B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1]))
B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2])); B23 = copy(transpose(B[3,2]))
B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = copy(transpose(B[3,3]))
elseif tB == 'C'
# TODO making these lazy could improve perf
B11 = copy(B[1,1]'); B12 = copy(B[2,1]'); B13 = copy(B[3,1]')
B21 = copy(B[1,2]'); B22 = copy(B[2,2]'); B23 = copy(B[3,2]')
B31 = copy(B[1,3]'); B32 = copy(B[2,3]'); B33 = copy(B[3,3]')
elseif tB == 'S'
B11 = symmetric(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3]
B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U); B23 = B[2,3]
B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = symmetric(B[3,3], :U)
elseif tB == 's'
B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1]))
B21 = B[2,1]; B22 = symmetric(B[2,2], :L); B23 = copy(transpose(B[3,2]))
B31 = B[3,1]; B32 = B[3,2]; B33 = symmetric(B[3,3], :L)
elseif tB == 'H'
B11 = hermitian(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3]
B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U); B23 = B[2,3]
B31 = copy(adjoint(B[1,3])); B32 = copy(adjoint(B[2,3])); B33 = hermitian(B[3,3], :U)
else # if tB == 'h'
B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1])); B13 = copy(adjoint(B[3,1]))
B21 = B[2,1]; B22 = hermitian(B[2,2], :L); B23 = copy(adjoint(B[3,2]))
B31 = B[3,1]; B32 = B[3,2]; B33 = hermitian(B[3,3], :L)
end
function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())

_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
_modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
_modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))
(A11, A12, A13, A21, A22, A23, A31, A32, A33),
(B11, B12, B13, B21, B22, B23, B31, B32, B33) = _matmul3x3_elements(C, tA, tB, A, B)

@inbounds begin
_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
_modify!(_add, A21*B11 + A22*B21 + A23*B31, C, (2,1))
_modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
_modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))

_modify!(_add, A31*B11 + A32*B21 + A33*B31, C, (3,1))

_modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
_modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
_modify!(_add, A31*B12 + A32*B22 + A33*B32, C, (3,2))

_modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))
_modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))
_modify!(_add, A31*B13 + A32*B23 + A33*B33, C, (3,3))
end # inbounds
C
Expand Down