1
1
module OceananigansAMDGPUExt
2
2
3
- using AMDGPU
4
3
using Oceananigans
4
+ using InteractiveUtils
5
+ using AMDGPU, AMDGPU. ROCSPARSE, AMDGPU. ROCFFT
5
6
using Oceananigans. Utils: linear_expand, __linear_ndrange, MappedCompilerMetadata
6
7
using KernelAbstractions: __dynamic_checkbounds, __iterspace
7
- import KernelAbstractions: __validindex
8
+ using KernelAbstractions
9
+ import Oceananigans. Architectures as AC
10
+ import Oceananigans. BoundaryConditions as BC
11
+ import Oceananigans. DistributedComputations as DC
12
+ import Oceananigans. Fields as FD
13
+ import Oceananigans. Grids as GD
14
+ import Oceananigans. Solvers as SO
15
+ import Oceananigans. Utils as UT
16
+ import SparseArrays: SparseMatrixCSC
17
+ import KernelAbstractions: __iterspace, __groupindex, __dynamic_checkbounds,
18
+ __validindex, CompilerMetadata
19
+ import Oceananigans. DistributedComputations: Distributed
8
20
9
- import Oceananigans. Architectures:
10
- architecture,
11
- convert_to_device,
12
- on_architecture
21
+ const GPUVar = Union{ROCArray, CuContext, CuPtr, Ptr}
13
22
14
- const ROCGPU = GPU{<: AMDGPU.ROCBackend }
15
- ROCGPU () = GPU (AMDGPU. ROCBackend ())
23
+ function __init__ ()
24
+ if AMDGPU. functional ()
25
+ @debug " ROCm-enabled GPU(s) detected:"
26
+ for (gpu, dev) in enumerate (AMDGPU. devices ())
27
+ @debug " $dev : $(AMDGPU. name (dev)) "
28
+ end
29
+ end
30
+ end
31
+
32
+ const ROCGPU = AC. GPU{ROCBackend}
33
+ ROCGPU () = AC. GPU (AMDGPU. ROCBackend ())
16
34
17
35
architecture (:: ROCArray ) = ROCGPU ()
18
36
Base. summary (:: ROCGPU ) = " ROCGPU"
19
37
20
- on_architecture (:: ROCGPU , a:: Number ) = a
21
- on_architecture (:: ROCGPU , a:: Array ) = ROCArray (a)
22
- on_architecture (:: ROCGPU , a:: BitArray ) = ROCArray (a)
23
- on_architecture (:: ROCGPU , a:: SubArray{<:Any, <:Any, <:Array} ) = ROCArray (a)
24
- on_architecture (:: CPU , a:: ROCArray ) = Array (a)
25
- on_architecture (:: CPU , a:: SubArray{<:Any, <:Any, <:ROCArray} ) = Array (a)
26
- on_architecture (:: ROCGPU , a:: ROCArray ) = a
27
- on_architecture (:: ROCGPU , a:: SubArray{<:Any, <:Any, <:ROCArray} ) = a
28
- on_architecture (:: ROCGPU , a:: StepRangeLen ) = a
38
+ AC. architecture (:: ROCArray ) = ROCGPU ()
39
+ AC. architecture (:: ROCSparseMatrixCSC ) = ROCGPU ()
40
+ AC. array_type (:: AC.GPU{ROCBackend} ) = ROCArray
41
+
42
+ AC. on_architecture (:: ROCGPU , a:: Number ) = a
43
+ AC. on_architecture (:: AC.CPU , a:: ROCArray ) = Array (a)
44
+ AC. on_architecture (:: ROCGPU , a:: Array ) = ROCArray (a)
45
+ AC. on_architecture (:: ROCGPU , a:: ROCArray ) = a
46
+ AC. on_architecture (:: ROCGPU , a:: BitArray ) = ROCArray (a)
47
+ AC. on_architecture (:: ROCGPU , a:: SubArray{<:Any, <:Any, <:ROCArray} ) = a
48
+ AC. on_architecture (:: ROCGPU , a:: SubArray{<:Any, <:Any, <:Array} ) = ROCArray (a)
49
+ AC. on_architecture (:: CPU , a:: SubArray{<:Any, <:Any, <:ROCArray} ) = Array (a)
50
+ AC. on_architecture (:: ROCGPU , a:: StepRangeLen ) = a
51
+ AC. on_architecture (arch:: Distributed , a:: ROCArray ) = AC. on_architecture (AC. child_architecture (arch), a)
52
+ AC. on_architecture (arch:: Distributed , a:: SubArray{<:Any, <:Any, <:ROCArray} ) = AC. on_architecture (child_architecture (arch), a)
53
+
54
+ function unified_array (:: AMDGPU , a:: AbstractArray )
55
+ error (" unified_array is not implemented for ROCGPU." )
56
+ end
57
+
58
+ # # GPU to GPU copy of contiguous data
59
+ @inline function AC. device_copy_to! (dst:: ROCArray , src:: ROCArray ; async:: Bool = false )
60
+ if async == true
61
+ @warn " Asynchronous copy is not supported for ROCArray. Falling back to synchronous copy."
62
+ end
63
+ copyto! (dst, src)
64
+ return dst
65
+ end
66
+
67
+ @inline AC. unsafe_free! (a:: ROCArray ) = AMDGPU. unsafe_free! (a)
68
+
69
+ @inline AC. constructors (:: AC.GPU{ROCBackend} , A:: SparseMatrixCSC ) = (ROCArray (A. colptr), ROCArray (A. rowval), ROCArray (A. nzval), (A. m, A. n))
70
+ @inline AC. constructors (:: AC.CPU , A:: ROCSparseMatrixCSC ) = (A. dims[1 ], A. dims[2 ], Int64 .(Array (A. colPtr)), Int64 .(Array (A. rowVal)), Array (A. nzVal))
71
+ @inline AC. constructors (:: AC.GPU{ROCBackend} , A:: ROCSparseMatrixCSC ) = (A. colPtr, A. rowVal, A. nzVal, A. dims)
72
+
73
+ @inline AC. arch_sparse_matrix (:: AC.GPU{ROCBackend} , constr:: Tuple ) = ROCSparseMatrixCSC (constr... )
74
+ @inline AC. arch_sparse_matrix (:: AC.CPU , A:: ROCSparseMatrixCSC ) = SparseMatrixCSC (AC. constructors (AC. CPU (), A)... )
75
+ @inline AC. arch_sparse_matrix (:: AC.GPU{ROCBackend} , A:: SparseMatrixCSC ) = ROCSparseMatrixCSC (AC. constructors (AC. GPU (), A)... )
76
+ @inline AC. arch_sparse_matrix (:: AC.GPU{ROCBackend} , A:: ROCSparseMatrixCSC ) = A
29
77
30
78
@inline convert_to_device (:: ROCGPU , args) = AMDGPU. rocconvert (args)
31
79
@inline convert_to_device (:: ROCGPU , args:: Tuple ) = map (AMDGPU. rocconvert, args)
32
80
81
+
82
+ BC. validate_boundary_condition_architecture (:: ROCArray , :: AC.GPU , bc, side) = nothing
83
+
84
+ BC. validate_boundary_condition_architecture (:: ROCArray , :: AC.CPU , bc, side) =
85
+ throw (ArgumentError (" $side $bc must use `Array` rather than `ROCArray` on CPU architectures!" ))
86
+
87
+ function SO. plan_forward_transform (A:: ROCArray , :: Union{GD.Bounded, GD.Periodic} , dims, planner_flag)
88
+ length (dims) == 0 && return nothing
89
+ return AMDGPU. ROCFFT. plan_fft! (A, dims)
90
+ end
91
+
92
+ FD. set! (v:: Field , a:: ROCArray ) = FD. _set! (v, a)
93
+ DC. set! (v:: DC.DistributedField , a:: ROCArray ) = DC. _set! (v, a)
94
+
95
+ function SO. plan_backward_transform (A:: ROCArray , :: Union{GD.Bounded, GD.Periodic} , dims, planner_flag)
96
+ length (dims) == 0 && return nothing
97
+ return AMDGPU. ROCFFT. plan_ifft! (A, dims)
98
+ end
99
+
33
100
AMDGPU. Device. @device_override @inline function __validindex (ctx:: MappedCompilerMetadata )
34
101
if __dynamic_checkbounds (ctx)
35
102
I = @inbounds linear_expand (__iterspace (ctx), AMDGPU. Device. blockIdx (). x, AMDGPU. Device. threadIdx (). x)
@@ -39,4 +106,11 @@ AMDGPU.Device.@device_override @inline function __validindex(ctx::MappedCompiler
39
106
end
40
107
end
41
108
109
+ @inline UT. sync_device! (:: ROCDevice ) = ROC. synchronize ()
110
+ @inline UT. getdevice (roc:: GPUVar , i) = device (roc)
111
+ @inline UT. getdevice (roc:: GPUVar ) = device (roc)
112
+ @inline UT. switch_device! (dev:: ROCDevice ) = device! (dev)
113
+ @inline UT. sync_device! (:: ROCGPU ) = ROC. synchronize ()
114
+ @inline UT. sync_device! (:: ROCBackend ) = ROC. synchronize ()
115
+
42
116
end # module
0 commit comments