diff --git a/.gitignore b/.gitignore index 9b73c974..f70d3dc4 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ Manifest.toml # MacOS generated files *.DS_Store + +/.vscode/ diff --git a/Project.toml b/Project.toml index 551c9edc..33213c5a 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] @@ -24,5 +25,6 @@ Printf = "1" Random = "1" Reexport = "1" Serialization = "1" +SparseArrays = "1" Statistics = "1" julia = "1.10" diff --git a/lib/GPUArraysCore/Manifest.toml b/lib/GPUArraysCore/Manifest.toml deleted file mode 100644 index ca31d72b..00000000 --- a/lib/GPUArraysCore/Manifest.toml +++ /dev/null @@ -1,14 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "af92965fb30777147966f58acb05da51c5616b5f" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.3.3" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/lib/GPUArraysCore/Project.toml b/lib/GPUArraysCore/Project.toml index 00ea6303..ba4b1ee4 100644 --- a/lib/GPUArraysCore/Project.toml +++ b/lib/GPUArraysCore/Project.toml @@ -5,7 +5,11 @@ version = "0.2.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] Adapt = "4.0" +LinearAlgebra = "1" +SparseArrays = "1" julia = "1.6" diff --git a/lib/GPUArraysCore/src/GPUArraysCore.jl b/lib/GPUArraysCore/src/GPUArraysCore.jl index bcf24e60..409c8303 100644 --- a/lib/GPUArraysCore/src/GPUArraysCore.jl +++ b/lib/GPUArraysCore/src/GPUArraysCore.jl @@ -1,13 +1,17 @@ module GPUArraysCore using Adapt - +using LinearAlgebra +using SparseArrays ## essential types export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat, - WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle, - AnyGPUArray, AnyGPUVector, AnyGPUMatrix + WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle, + AnyGPUArray, AnyGPUVector, AnyGPUMatrix + +export AbstractGPUSparseArray, AbstractGPUSparseMatrix, AbstractGPUSparseVector, AbstractGPUSparseVecOrMat, + AbstractGPUSparseMatrixCSC, AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCOO, AnyGPUSparseMatrixCSC, AnyGPUSparseMatrixCSR, AnyGPUSparseMatrixCOO """ AbstractGPUArray{T, N} <: DenseArray{T, N} @@ -16,18 +20,33 @@ Supertype for `N`-dimensional GPU arrays (or array-like types) with elements of Instances of this type are expected to live on the host, see [`AbstractDeviceArray`](@ref) for device-side objects. """ -abstract type AbstractGPUArray{T, N} <: DenseArray{T, N} end +abstract type AbstractGPUArray{T,N} <: DenseArray{T,N} end -const AbstractGPUVector{T} = AbstractGPUArray{T, 1} -const AbstractGPUMatrix{T} = AbstractGPUArray{T, 2} -const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T, 1}, AbstractGPUArray{T, 2}} +const AbstractGPUVector{T} = AbstractGPUArray{T,1} +const AbstractGPUMatrix{T} = AbstractGPUArray{T,2} +const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T,1},AbstractGPUArray{T,2}} # convenience aliases for working with wrapped arrays const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}} -const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}} -const AnyGPUVector{T} = AnyGPUArray{T, 1} -const AnyGPUMatrix{T} = AnyGPUArray{T, 2} +const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N},WrappedGPUArray{T,N}} +const AnyGPUVector{T} = AnyGPUArray{T,1} +const AnyGPUMatrix{T} = AnyGPUArray{T,2} + +## sparse arrays + +abstract type AbstractGPUSparseArray{Tv,Ti,N} <: AbstractSparseArray{Tv,Ti,N} end + +const AbstractGPUSparseMatrix{Tv,Ti} = AbstractGPUSparseArray{Tv,Ti,2} +const AbstractGPUSparseVector{Tv,Ti} = AbstractGPUSparseArray{Tv,Ti,1} +const AbstractGPUSparseVecOrMat{Tv,Ti} = Union{AbstractGPUSparseVector{Tv,Ti},AbstractGPUSparseMatrix{Tv,Ti}} + +abstract type AbstractGPUSparseMatrixCSC{Tv,Ti<:Integer} <: AbstractGPUSparseMatrix{Tv,Ti} end +abstract type AbstractGPUSparseMatrixCSR{Tv,Ti<:Integer} <: AbstractGPUSparseMatrix{Tv,Ti} end +abstract type AbstractGPUSparseMatrixCOO{Tv,Ti<:Integer} <: AbstractGPUSparseMatrix{Tv,Ti} end +const AnyGPUSparseMatrixCSC{Tv,Ti} = Union{AbstractGPUSparseMatrixCSC{Tv,Ti},Transpose{Tv,<:AbstractGPUSparseMatrixCSC{Tv,Ti}},Adjoint{Tv,<:AbstractGPUSparseMatrixCSC{Tv,Ti}}} +const AnyGPUSparseMatrixCSR{Tv,Ti} = Union{AbstractGPUSparseMatrixCSR{Tv,Ti},Transpose{Tv,<:AbstractGPUSparseMatrixCSR{Tv,Ti}},Adjoint{Tv,<:AbstractGPUSparseMatrixCSR{Tv,Ti}}} +const AnyGPUSparseMatrixCOO{Tv,Ti} = Union{AbstractGPUSparseMatrixCOO{Tv,Ti},Transpose{Tv,<:AbstractGPUSparseMatrixCOO{Tv,Ti}},Adjoint{Tv,<:AbstractGPUSparseMatrixCOO{Tv,Ti}}} ## broadcasting @@ -157,9 +176,9 @@ end # this problem will be introduced in https://github.com/JuliaLang/julia/pull/39217 macro __tryfinally(ex, fin) Expr(:tryfinally, - :($(esc(ex))), - :($(esc(fin))) - ) + :($(esc(ex))), + :($(esc(fin))) + ) end """ @@ -182,7 +201,7 @@ end function allowscalar(allow::Bool=true) if allow @warn """It's not recommended to use allowscalar([true]) to allow scalar indexing. - Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.""" maxlog=1 + Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.""" maxlog = 1 end setting = allow ? ScalarAllowed : ScalarDisallowed task_local_storage(:ScalarIndexing, setting) @@ -204,8 +223,8 @@ macro allowscalar(ex) local tls_value = get(task_local_storage(), :ScalarIndexing, nothing) task_local_storage(:ScalarIndexing, ScalarAllowed) @__tryfinally($(esc(ex)), - isnothing(tls_value) ? delete!(task_local_storage(), :ScalarIndexing) - : task_local_storage(:ScalarIndexing, tls_value)) + isnothing(tls_value) ? delete!(task_local_storage(), :ScalarIndexing) + : task_local_storage(:ScalarIndexing, tls_value)) end end diff --git a/lib/JLArrays/Project.toml b/lib/JLArrays/Project.toml index 55d9eb90..0c78e0af 100644 --- a/lib/JLArrays/Project.toml +++ b/lib/JLArrays/Project.toml @@ -8,10 +8,12 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] Adapt = "2.0, 3.0, 4.0" GPUArrays = "11.1" KernelAbstractions = "0.9" Random = "1" +SparseArrays = "1" julia = "1.8" diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index dbec3d25..94418265 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -11,6 +11,8 @@ export JLArray, JLVector, JLMatrix, jl, JLBackend using GPUArrays using Adapt +using SparseArrays +using SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds import KernelAbstractions import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config @@ -387,4 +389,6 @@ Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a) Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a) +include("sparse.jl") + end diff --git a/lib/JLArrays/src/sparse.jl b/lib/JLArrays/src/sparse.jl new file mode 100644 index 00000000..bf1f0bc8 --- /dev/null +++ b/lib/JLArrays/src/sparse.jl @@ -0,0 +1,95 @@ +export JLSparseVector, JLSparseMatrixCSC + +## Sparse Vector + +struct JLSparseVector{Tv,Ti<:Integer} <: AbstractGPUSparseVector{Tv,Ti} + n::Ti # Length of the sparse vector + nzind::JLVector{Ti} # Indices of stored values + nzval::JLVector{Tv} # Stored values, typically nonzeros + + function JLSparseVector{Tv,Ti}(n::Integer, nzind::JLVector{Ti}, nzval::JLVector{Tv}) where {Tv,Ti<:Integer} + n >= 0 || throw(ArgumentError("The number of elements must be non-negative.")) + length(nzind) == length(nzval) || + throw(ArgumentError("index and value vectors must be the same length")) + new(convert(Ti, n), nzind, nzval) + end +end + +JLSparseVector(n::Integer, nzind::JLVector{Ti}, nzval::JLVector{Tv}) where {Tv,Ti} = + JLSparseVector{Tv,Ti}(n, nzind, nzval) + +JLSparseVector(V::SparseVector) = JLSparseVector(V.n, JLVector(V.nzind), JLVector(V.nzval)) +SparseVector(V::JLSparseVector) = SparseVector(V.n, Vector(V.nzind), Vector(V.nzval)) + +Base.copy(V::JLSparseVector) = JLSparseVector(V.n, copy(V.nzind), copy(V.nzval)) + +Base.length(V::JLSparseVector) = V.n +Base.size(V::JLSparseVector) = (V.n,) + +SparseArrays.nonzeros(V::JLSparseVector) = V.nzval +SparseArrays.nonzeroinds(V::JLSparseVector) = V.nzind + +## SparseMatrixCSC + +struct JLSparseMatrixCSC{Tv,Ti<:Integer} <: AbstractGPUSparseMatrixCSC{Tv,Ti} + m::Int # Number of rows + n::Int # Number of columns + colptr::JLVector{Ti} # Column i is in colptr[i]:(colptr[i+1]-1) + rowval::JLVector{Ti} # Row indices of stored values + nzval::JLVector{Tv} # Stored values, typically nonzeros + + function JLSparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::JLVector{Ti}, + rowval::JLVector{Ti}, nzval::JLVector{Tv}) where {Tv,Ti<:Integer} + SparseArrays.sparse_check_Ti(m, n, Ti) + GPUArrays._goodbuffers_csc(m, n, colptr, rowval, nzval) || + throw(ArgumentError("Invalid buffers for JLSparseMatrixCSC construction n=$n, colptr=$(summary(colptr)), rowval=$(summary(rowval)), nzval=$(summary(nzval))")) + new(Int(m), Int(n), colptr, rowval, nzval) + end +end +function JLSparseMatrixCSC(m::Integer, n::Integer, colptr::JLVector, rowval::JLVector, nzval::JLVector) + Tv = eltype(nzval) + Ti = promote_type(eltype(colptr), eltype(rowval)) + SparseArrays.sparse_check_Ti(m, n, Ti) + # SparseArrays.sparse_check(n, colptr, rowval, nzval) # TODO: this uses scalar indexing + # silently shorten rowval and nzval to usable index positions. + maxlen = abs(widemul(m, n)) + isbitstype(Ti) && (maxlen = min(maxlen, typemax(Ti) - 1)) + length(rowval) > maxlen && resize!(rowval, maxlen) + length(nzval) > maxlen && resize!(nzval, maxlen) + JLSparseMatrixCSC{Tv,Ti}(m, n, colptr, rowval, nzval) +end + +JLSparseMatrixCSC(A::SparseMatrixCSC) = JLSparseMatrixCSC(A.m, A.n, JLVector(A.colptr), JLVector(A.rowval), JLVector(A.nzval)) +SparseMatrixCSC(A::JLSparseMatrixCSC) = SparseMatrixCSC(A.m, A.n, Vector(A.colptr), Vector(A.rowval), Vector(A.nzval)) + +Base.copy(A::JLSparseMatrixCSC) = JLSparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), copy(A.nzval)) + +Base.size(A::JLSparseMatrixCSC) = (A.m, A.n) +Base.length(A::JLSparseMatrixCSC) = A.m * A.n + +SparseArrays.nonzeros(A::JLSparseMatrixCSC) = A.nzval +SparseArrays.getcolptr(A::JLSparseMatrixCSC) = A.colptr +SparseArrays.rowvals(A::JLSparseMatrixCSC) = A.rowval + +## Device + +function Adapt.adapt_structure(to, A::JLSparseMatrixCSC) + m = A.m + n = A.n + colptr = Adapt.adapt(to, getcolptr(A)) + rowval = Adapt.adapt(to, rowvals(A)) + nzval = Adapt.adapt(to, nonzeros(A)) + return JLSparseDeviceMatrixCSC(m, n, colptr, rowval, nzval) +end + +struct JLSparseDeviceMatrixCSC{Tv,Ti} <: AbstractGPUSparseMatrixCSC{Tv,Ti} + m::Int + n::Int + colptr::JLDeviceArray{Ti,1} + rowval::JLDeviceArray{Ti,1} + nzval::JLDeviceArray{Tv,1} +end + +SparseArrays.nonzeros(A::JLSparseDeviceMatrixCSC) = A.nzval +SparseArrays.getcolptr(A::JLSparseDeviceMatrixCSC) = A.colptr +SparseArrays.rowvals(A::JLSparseDeviceMatrixCSC) = A.rowval diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl index 418b87b5..a47cd94d 100644 --- a/src/GPUArrays.jl +++ b/src/GPUArrays.jl @@ -4,6 +4,9 @@ using KernelAbstractions using Serialization using Random using LinearAlgebra +using SparseArrays +using SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds + using Printf using LinearAlgebra.BLAS @@ -15,8 +18,6 @@ using LLVM.Interop using Reexport @reexport using GPUArraysCore -using KernelAbstractions - # device functionality include("device/abstractarray.jl") @@ -33,6 +34,7 @@ include("host/math.jl") include("host/random.jl") include("host/quirks.jl") include("host/uniformscaling.jl") +include("host/sparse.jl") include("host/statistics.jl") diff --git a/src/host/sparse.jl b/src/host/sparse.jl new file mode 100644 index 00000000..01bad273 --- /dev/null +++ b/src/host/sparse.jl @@ -0,0 +1,181 @@ +## Wrappers + +trans_adj_wrappers_dense_vecormat = ((T -> :(AbstractGPUVecOrMat{$T}), false, identity, identity), + (T -> :(Transpose{$T,<:AbstractGPUMatrix{$T}}), true, identity, A -> :(parent($A))), + (T -> :(Adjoint{$T,<:AbstractGPUMatrix{$T}}), true, x -> :(conj($x)), A -> :(parent($A)))) + +trans_adj_wrappers_csc = ((T -> :(AbstractGPUSparseMatrixCSC{$T}), false, identity, identity), + (T -> :(Transpose{$T,<:AbstractGPUSparseMatrixCSC{$T}}), true, identity, A -> :(parent($A))), + (T -> :(Adjoint{$T,<:AbstractGPUSparseMatrixCSC{$T}}), true, x -> :(conj($x)), A -> :(parent($A)))) + +## Sparse Vector + +SparseArrays.getnzval(V::AbstractGPUSparseVector) = nonzeros(V) +SparseArrays.nnz(V::AbstractGPUSparseVector) = length(nzval(V)) + +function unsafe_free!(V::AbstractGPUSparseVector) + unsafe_free!(nonzeroinds(V)) + unsafe_free!(nonzeros(V)) + return nothing +end + +function Base.sizehint!(V::AbstractGPUSparseVector, newlen::Integer) + sizehint!(nonzeroinds(V), newlen) + sizehint!(nonzeros(V), newlen) + return V +end + +Base.copy(V::AbstractGPUSparseVector) = typeof(V)(length(V), copy(nonzeroinds(V)), copy(nonzeros(V))) +Base.similar(V::AbstractGPUSparseVector) = copy(V) # We keep the same sparsity of the source + +Base.:(*)(α::Number, V::AbstractGPUSparseVector) = typeof(V)(length(V), copy(nonzeroinds(V)), α * nonzeros(V)) +Base.:(*)(V::AbstractGPUSparseVector, α::Number) = α * V +Base.:(/)(V::AbstractGPUSparseVector, α::Number) = typeof(V)(length(V), copy(nonzeroinds(V)), nonzeros(V) / α) + +function LinearAlgebra.dot(x::AbstractGPUSparseVector, y::AbstractGPUVector) + n = length(y) + length(x) == n || throw(DimensionMismatch( + "Vector x has a length $(length(x)) but y has a length $n")) + nzind = nonzeroinds(x) + nzval = nonzeros(x) + y_view = y[nzind] # TODO: by using the view it throws scalar indexing + return dot(nzval, y_view) +end +LinearAlgebra.dot(x::AbstractGPUVector{T}, y::AbstractGPUSparseVector{T}) where {T<:Real} = dot(y, x) +LinearAlgebra.dot(x::AbstractGPUVector{T}, y::AbstractGPUSparseVector{T}) where {T<:Complex} = conj(dot(y, x)) + + +## General Sparse Matrix + +KernelAbstractions.get_backend(A::AbstractGPUSparseMatrix) = KernelAbstractions.get_backend(getnzval(A)) + +SparseArrays.getnzval(A::AbstractGPUSparseMatrix) = nonzeros(A) +SparseArrays.nnz(A::AbstractGPUSparseMatrix) = length(getnzval(A)) + +function LinearAlgebra.rmul!(A::AbstractGPUSparseMatrix, x::Number) + rmul!(getnzval(A), x) + return A +end + +function LinearAlgebra.lmul!(x::Number, A::AbstractGPUSparseMatrix) + lmul!(x, getnzval(A)) + return A +end + +## CSC Matrix + +SparseArrays.getrowval(A::AbstractGPUSparseMatrixCSC) = rowvals(A) +# SparseArrays.nzrange(A::AbstractGPUSparseMatrixCSC, col::Integer) = getcolptr(A)[col]:(getcolptr(A)[col+1]-1) # TODO: this uses scalar indexing + +function unsafe_free!(A::AbstractGPUSparseMatrixCSC) + unsafe_free!(getcolptr(A)) + unsafe_free!(rowvals(A)) + unsafe_free!(nonzeros(A)) + return nothing +end + +Base.copy(A::AbstractGPUSparseMatrixCSC) = typeof(A)(size(A), copy(getcolptr(A)), copy(rowvals(A)), copy(getnzval(A))) +Base.similar(A::AbstractGPUSparseMatrixCSC) = copy(A) # We keep the same sparsity of the source + +Base.:(*)(α::Number, A::AbstractGPUSparseMatrixCSC) = typeof(A)(size(A), copy(getcolptr(A)), copy(rowvals(A)), α * nonzeros(A)) +Base.:(*)(A::AbstractGPUSparseMatrixCSC, α::Number) = α * A +Base.:(/)(A::AbstractGPUSparseMatrixCSC, α::Number) = typeof(A)(size(A), copy(getcolptr(A)), copy(rowvals(A)), nonzeros(A) / α) + +@inline function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA, A::AbstractGPUSparseMatrixCSC, B::AbstractGPUVector, _add::LinearAlgebra.MulAddMul) + return _spmatmul!(C, wrap(A, tA), B, _add.alpha, _add.beta) +end + +@inline function LinearAlgebra.generic_matmatmul!(C::AbstractGPUMatrix, tA, tb, A::AbstractGPUSparseMatrixCSC, B::AbstractGPUMatrix, _add::LinearAlgebra.MulAddMul) + return _spmatmul!(C, wrap(A, tA), wrap(B, tb), _add.alpha, _add.beta) +end + +for (wrapa, transa, opa, unwrapa) in trans_adj_wrappers_csc + for (wrapb, transb, opb, unwrapb) in trans_adj_wrappers_dense_vecormat + TypeA = wrapa(:(T1)) + TypeB = wrapb(:(T2)) + TypeC = :(AbstractGPUVecOrMat{T3}) + + kernel_spmatmul! = transa ? :kernel_spmatmul_T! : :kernel_spmatmul_N! + + indB = transb ? (i, j) -> :(($j, $i)) : (i, j) -> :(($i, $j)) # transpose indices + + @eval function _spmatmul!(C::$TypeC, A::$TypeA, B::$TypeB, α::Number, β::Number) where {T1,T2,T3} + size(A, 2) == size(B, 1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))")) + size(A, 1) == size(C, 1) || + throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))")) + size(B, 2) == size(C, 2) || + throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) + + _A = $(unwrapa(:A)) + _B = $(unwrapb(:B)) + + backend_C = KernelAbstractions.get_backend(C) + backend_A = KernelAbstractions.get_backend(_A) + backend_B = KernelAbstractions.get_backend(_B) + + backend_A == backend_B == backend_C || throw(ArgumentError("All arrays must be on the same backend")) + + @kernel function kernel_spmatmul_N!(C, @Const(A), @Const(B)) + k, col = @index(Global, NTuple) + + Bi, Bj = $(indB(:col, :k)) + + @inbounds axj = $(opb(:(B[Bi, Bj]))) * α + @inbounds for j in getcolptr(A)[col]:(getcolptr(A)[col+1]-1) # nzrange(A, col) + KernelAbstractions.@atomic C[getrowval(A)[j], k] += $(opa(:(getnzval(A)[j]))) * axj + end + end + + @kernel function kernel_spmatmul_T!(C, @Const(A), @Const(B)) + k, col = @index(Global, NTuple) + + tmp = zero(eltype(C)) + @inbounds for j in getcolptr(A)[col]:(getcolptr(A)[col+1]-1) # nzrange(A, col) + Bi, Bj = $(indB(:(getrowval(A)[j]), :k)) + tmp += $(opa(:(getnzval(A)[j]))) * $(opb(:(B[Bi, Bj]))) + end + @inbounds C[col, k] += tmp * α + end + + β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) + + kernel! = $kernel_spmatmul!(backend_A) + kernel!(C, _A, _B; ndrange=(size(C, 2), size(_A, 2))) + + return C + end + end +end + +function _goodbuffers_csc(m, n, colptr, rowval, nzval) + return (length(colptr) == n + 1 && length(rowval) == length(nzval)) + # TODO: also add the condition that colptr[end] - 1 == length(nzval) (allowscalar?) +end + +## Broadcasting + +# broadcast container type promotion for combinations of sparse arrays and other types +struct GPUSparseVecStyle <: Broadcast.AbstractArrayStyle{1} end +struct GPUSparseMatStyle <: Broadcast.AbstractArrayStyle{2} end +Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseVector}) = GPUSparseVecStyle() +Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseMatrix}) = GPUSparseMatStyle() +const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle} + +# GPUSparseVecStyle handles 0-1 dimensions, GPUSparseMatStyle 0-2 dimensions. +# GPUSparseVecStyle promotes to GPUSparseMatStyle for 2 dimensions. +# Fall back to DefaultArrayStyle for higher dimensionality. +GPUSparseVecStyle(::Val{0}) = GPUSparseVecStyle() +GPUSparseVecStyle(::Val{1}) = GPUSparseVecStyle() +GPUSparseVecStyle(::Val{2}) = GPUSparseMatStyle() +GPUSparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() +GPUSparseMatStyle(::Val{0}) = GPUSparseMatStyle() +GPUSparseMatStyle(::Val{1}) = GPUSparseMatStyle() +GPUSparseMatStyle(::Val{2}) = GPUSparseMatStyle() +GPUSparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() + +Broadcast.BroadcastStyle(::GPUSparseMatStyle, ::GPUSparseVecStyle) = GPUSparseMatStyle() + +# Tuples promote to dense +Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}() +Broadcast.BroadcastStyle(::GPUSparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()