From 677038ae14769c63a019b1b2f88294f98e13ebfa Mon Sep 17 00:00:00 2001 From: Evelyne <110474206+evelyne-ringoot@users.noreply.github.com> Date: Mon, 13 Feb 2023 17:26:46 -0500 Subject: [PATCH 1/2] Changing transpose and ajoint functions from AbstractGPUArray to AnyGPUArray --- lib/GPUArraysCore/src/GPUArraysCore.jl | 4 +++- src/host/linalg.jl | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/GPUArraysCore/src/GPUArraysCore.jl b/lib/GPUArraysCore/src/GPUArraysCore.jl index 1a78e270..2d4131fd 100644 --- a/lib/GPUArraysCore/src/GPUArraysCore.jl +++ b/lib/GPUArraysCore/src/GPUArraysCore.jl @@ -6,7 +6,7 @@ using Adapt ## essential types export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat, - WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle + WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle, AnyGPUArray, AnyGPUMatrix """ AbstractGPUArray{T, N} <: DenseArray{T, N} @@ -24,6 +24,8 @@ const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T, 1}, AbstractGPUArray{T, # convenience aliases for working with wrapped arrays const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}} const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}} +const AnyGPUVector{T} = AnyGPUArray{T, 1} +const AnyGPUMatrix{T} = AnyGPUArray{T, 2} ## broadcasting diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 64f49241..5db045d9 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -29,9 +29,9 @@ function LinearAlgebra.adjoint!(B::AbstractGPUMatrix, A::AbstractGPUVector) B end -LinearAlgebra.transpose!(B::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(transpose, B, A) -LinearAlgebra.adjoint!(B::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(adjoint, B, A) -function transpose_f!(f, B::AbstractGPUMatrix{T}, A::AbstractGPUMatrix{T}) where T +LinearAlgebra.transpose!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(transpose, B, A) +LinearAlgebra.adjoint!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(adjoint, B, A) +function transpose_f!(f, B::AnyGPUMatrix{T}, A::AnyGPUMatrix{T}) where T axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) || throw(DimensionMismatch(string(f))) gpu_call(B, A) do ctx, B, A idx = @cartesianidx A From 0135f5e5a793b36b58ee104070cfe71d09bec606 Mon Sep 17 00:00:00 2001 From: Evelyne <110474206+evelyne-ringoot@users.noreply.github.com> Date: Tue, 14 Feb 2023 10:37:05 -0500 Subject: [PATCH 2/2] Update lib/GPUArraysCore/src/GPUArraysCore.jl Co-authored-by: Tim Besard --- lib/GPUArraysCore/src/GPUArraysCore.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/GPUArraysCore/src/GPUArraysCore.jl b/lib/GPUArraysCore/src/GPUArraysCore.jl index 2d4131fd..b2edff94 100644 --- a/lib/GPUArraysCore/src/GPUArraysCore.jl +++ b/lib/GPUArraysCore/src/GPUArraysCore.jl @@ -6,7 +6,8 @@ using Adapt ## essential types export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat, - WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle, AnyGPUArray, AnyGPUMatrix + WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle, + AnyGPUArray, AnyGPUVector, AnyGPUMatrix """ AbstractGPUArray{T, N} <: DenseArray{T, N}