Skip to content

Reduce matmul latency by splitting small matmul #54421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 89 additions & 110 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -930,164 +930,143 @@ end


# multiply 2x2 matrices
function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
Base.@constprop :aggressive function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
function __matmul_checks(C, A, B, sz)
require_one_based_indexing(C, A, B)
if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end
if !(size(A) == size(B) == size(C) == (2,2))
if !(size(A) == size(B) == size(C) == sz)
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
end
return nothing
end

# separate function with the core of matmul2x2! that doesn't depend on a MulAddMul
Base.@constprop :aggressive function _matmul2x2_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix)
__matmul_checks(C, A, B, (2,2))
__matmul2x2_elements(tA, tB, A, B)
end
Base.@constprop :aggressive function __matmul2x2_elements(tA, A::AbstractMatrix)
@inbounds begin
if tA == 'N'
tA_uc = uppercase(tA) # possibly unwrap a WrapperChar
if tA_uc == 'N'
A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2]
elseif tA == 'T'
elseif tA_uc == 'T'
# TODO making these lazy could improve perf
A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1]))
A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2]))
elseif tA == 'C'
elseif tA_uc == 'C'
# TODO making these lazy could improve perf
A11 = copy(A[1,1]'); A12 = copy(A[2,1]')
A21 = copy(A[1,2]'); A22 = copy(A[2,2]')
elseif tA == 'S'
A11 = symmetric(A[1,1], :U); A12 = A[1,2]
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U)
elseif tA == 's'
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1]))
A21 = A[2,1]; A22 = symmetric(A[2,2], :L)
elseif tA == 'H'
A11 = hermitian(A[1,1], :U); A12 = A[1,2]
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U)
else # if tA == 'h'
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L)
end
if tB == 'N'
B11 = B[1,1]; B12 = B[1,2];
B21 = B[2,1]; B22 = B[2,2]
elseif tB == 'T'
# TODO making these lazy could improve perf
B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1]))
B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2]))
elseif tB == 'C'
# TODO making these lazy could improve perf
B11 = copy(B[1,1]'); B12 = copy(B[2,1]')
B21 = copy(B[1,2]'); B22 = copy(B[2,2]')
elseif tB == 'S'
B11 = symmetric(B[1,1], :U); B12 = B[1,2]
B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U)
elseif tB == 's'
B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1]))
B21 = B[2,1]; B22 = symmetric(B[2,2], :L)
elseif tB == 'H'
B11 = hermitian(B[1,1], :U); B12 = B[1,2]
B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U)
else # if tB == 'h'
B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1]))
B21 = B[2,1]; B22 = hermitian(B[2,2], :L)
elseif tA_uc == 'S'
if isuppercase(tA) # tA == 'S'
A11 = symmetric(A[1,1], :U); A12 = A[1,2]
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U)
else
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1]))
A21 = A[2,1]; A22 = symmetric(A[2,2], :L)
end
elseif tA_uc == 'H'
if isuppercase(tA) # tA == 'H'
A11 = hermitian(A[1,1], :U); A12 = A[1,2]
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U)
else # if tA == 'h'
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L)
end
end
end # inbounds
A11, A12, A21, A22
end
Base.@constprop :aggressive __matmul2x2_elements(tA, tB, A, B) = __matmul2x2_elements(tA, A), __matmul2x2_elements(tB, B)

Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
(A11, A12, A21, A22), (B11, B12, B21, B22) = _matmul2x2_elements(C, tA, tB, A, B)
@inbounds begin
_modify!(_add, A11*B11 + A12*B21, C, (1,1))
_modify!(_add, A11*B12 + A12*B22, C, (1,2))
_modify!(_add, A21*B11 + A22*B21, C, (2,1))
_modify!(_add, A11*B12 + A12*B22, C, (1,2))
_modify!(_add, A21*B12 + A22*B22, C, (2,2))
end # inbounds
C
end

# Multiply 3x3 matrices
function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
Base.@constprop :aggressive function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end
if !(size(A) == size(B) == size(C) == (3,3))
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
end
# separate function with the core of matmul3x3! that doesn't depend on a MulAddMul
Base.@constprop :aggressive function _matmul3x3_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix)
__matmul_checks(C, A, B, (3,3))
__matmul3x3_elements(tA, tB, A, B)
end
Base.@constprop :aggressive function __matmul3x3_elements(tA, A::AbstractMatrix)
@inbounds begin
if tA == 'N'
tA_uc = uppercase(tA) # possibly unwrap a WrapperChar
if tA_uc == 'N'
A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3]
A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3]
A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3]
elseif tA == 'T'
elseif tA_uc == 'T'
# TODO making these lazy could improve perf
A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2])); A23 = copy(transpose(A[3,2]))
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = copy(transpose(A[3,3]))
elseif tA == 'C'
elseif tA_uc == 'C'
# TODO making these lazy could improve perf
A11 = copy(A[1,1]'); A12 = copy(A[2,1]'); A13 = copy(A[3,1]')
A21 = copy(A[1,2]'); A22 = copy(A[2,2]'); A23 = copy(A[3,2]')
A31 = copy(A[1,3]'); A32 = copy(A[2,3]'); A33 = copy(A[3,3]')
elseif tA == 'S'
A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3]
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U)
elseif tA == 's'
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L)
elseif tA == 'H'
A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3]
A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U)
else # if tA == 'h'
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L)
elseif tA_uc == 'S'
if isuppercase(tA) # tA == 'S'
A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3]
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U)
else
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L)
end
elseif tA_uc == 'H'
if isuppercase(tA) # tA == 'H'
A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3]
A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U)
else # if tA == 'h'
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1]))
A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2]))
A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L)
end
end
end # inbounds
A11, A12, A13, A21, A22, A23, A31, A32, A33
end
Base.@constprop :aggressive __matmul3x3_elements(tA, tB, A, B) = __matmul3x3_elements(tA, A), __matmul3x3_elements(tB, B)

if tB == 'N'
B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3]
B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3]
B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3]
elseif tB == 'T'
# TODO making these lazy could improve perf
B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1]))
B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2])); B23 = copy(transpose(B[3,2]))
B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = copy(transpose(B[3,3]))
elseif tB == 'C'
# TODO making these lazy could improve perf
B11 = copy(B[1,1]'); B12 = copy(B[2,1]'); B13 = copy(B[3,1]')
B21 = copy(B[1,2]'); B22 = copy(B[2,2]'); B23 = copy(B[3,2]')
B31 = copy(B[1,3]'); B32 = copy(B[2,3]'); B33 = copy(B[3,3]')
elseif tB == 'S'
B11 = symmetric(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3]
B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U); B23 = B[2,3]
B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = symmetric(B[3,3], :U)
elseif tB == 's'
B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1]))
B21 = B[2,1]; B22 = symmetric(B[2,2], :L); B23 = copy(transpose(B[3,2]))
B31 = B[3,1]; B32 = B[3,2]; B33 = symmetric(B[3,3], :L)
elseif tB == 'H'
B11 = hermitian(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3]
B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U); B23 = B[2,3]
B31 = copy(adjoint(B[1,3])); B32 = copy(adjoint(B[2,3])); B33 = hermitian(B[3,3], :U)
else # if tB == 'h'
B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1])); B13 = copy(adjoint(B[3,1]))
B21 = B[2,1]; B22 = hermitian(B[2,2], :L); B23 = copy(adjoint(B[3,2]))
B31 = B[3,1]; B32 = B[3,2]; B33 = hermitian(B[3,3], :L)
end
Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())

_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
_modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
_modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))
(A11, A12, A13, A21, A22, A23, A31, A32, A33),
(B11, B12, B13, B21, B22, B23, B31, B32, B33) = _matmul3x3_elements(C, tA, tB, A, B)

@inbounds begin
_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
_modify!(_add, A21*B11 + A22*B21 + A23*B31, C, (2,1))
_modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
_modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))

_modify!(_add, A31*B11 + A32*B21 + A33*B31, C, (3,1))

_modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
_modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
_modify!(_add, A31*B12 + A32*B22 + A33*B32, C, (3,2))

_modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))
_modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))
_modify!(_add, A31*B13 + A32*B23 + A33*B33, C, (3,3))
end # inbounds
C
Expand Down