(0.97.0) Move CUDA stuff to an extension #4499

Merged · 57 commits · Jul 14, 2025
Commits
6743772
Update .gitlab-ci.yml file
michel2323 Apr 11, 2025
314ddea
Adding Aurora CI
michel2323 Apr 11, 2025
5656758
Fix
michel2323 Apr 11, 2025
15e1fb1
Fix
michel2323 Apr 11, 2025
0eeb10f
Isolate CUDA
michel2323 May 12, 2025
146bbec
Create a CUDA extension
michel2323 May 13, 2025
d000990
Add basic CUDA extension test
michel2323 May 13, 2025
44b312d
Rebase and various fixes
michel2323 May 14, 2025
65884d4
No GPU in example
michel2323 May 20, 2025
a39d7f0
Fix unified_array
michel2323 May 20, 2025
01abe13
Add CUDA to test_init.jl
michel2323 Jun 5, 2025
56bb0f8
Move CUDA to KA in docstrings
michel2323 Jun 5, 2025
d511a1c
Populate AMDGPU extension
michel2323 Jun 12, 2025
6e151dd
AMDGPU fixes
michel2323 Jun 12, 2025
b052dc6
Fix MultiRegionObject
michel2323 Jun 18, 2025
38c5b77
Fix docs
michel2323 Jun 18, 2025
9238946
More MultiRegion fixes
michel2323 Jun 18, 2025
ae953fe
Fix test_tripolar_grid
michel2323 Jun 18, 2025
34ce8b0
Fix test_tripolar_grid
michel2323 Jun 18, 2025
adcb00e
Fix allowscalar
michel2323 Jun 19, 2025
c01eabf
One more MultiObject
michel2323 Jun 19, 2025
ae48a9c
Fix multi_region_implicit
michel2323 Jun 19, 2025
9b6104b
Fix docs
michel2323 Jun 19, 2025
36f28c0
backend -> device for now
michel2323 Jun 23, 2025
150bd0f
CI, why did I comment out this line?
michel2323 Jun 26, 2025
85f154e
Update OceananigansAMDGPUExt.jl
glwagner Jun 27, 2025
f862d5f
Update OceananigansCUDAExt.jl
glwagner Jun 27, 2025
eb2a138
Update src/MultiRegion/multi_region_utils.jl
glwagner Jun 27, 2025
14ee3f3
Merge branch 'main' into ms/ka
glwagner Jun 27, 2025
26dc29b
Update ext/OceananigansCUDAExt.jl
glwagner Jun 27, 2025
129b0dd
Update ext/OceananigansAMDGPUExt.jl
glwagner Jun 27, 2025
5ac15cf
Update ext/OceananigansAMDGPUExt.jl
glwagner Jun 27, 2025
8b9923d
Update src/Fields/set!.jl
glwagner Jun 27, 2025
83951f5
Merge branch 'main' into ms/ka
glwagner Jun 27, 2025
fb37d42
Merge branch 'main' into ms/ka
glwagner Jun 30, 2025
1d5f5b7
Merge branch 'main' into ms/ka
glwagner Jul 1, 2025
cdb70da
Merge branch 'main' into ms/ka
glwagner Jul 3, 2025
ad377b4
Merge branch 'main' into ms/ka
navidcy Jul 6, 2025
56c3c72
Merge remote-tracking branch 'upstream/main'
navidcy Jul 6, 2025
d2c19fa
Merge branch 'main' into ms/ka
navidcy Jul 6, 2025
22d62c2
leave empty line
navidcy Jul 6, 2025
4f79474
Delete .gitlab-ci.yml
navidcy Jul 6, 2025
492c259
load CUDA + disambiguate record method
navidcy Jul 6, 2025
97674d3
bump minor release
navidcy Jul 6, 2025
1b51351
Merge branch 'main' into ms/ka
navidcy Jul 6, 2025
0a73212
install CUDA
navidcy Jul 6, 2025
f065e2a
delete stray empty line
navidcy Jul 6, 2025
f960b63
add backend when creating MultiRegionObject
navidcy Jul 6, 2025
63539ed
missed =
navidcy Jul 6, 2025
05b1d99
remove some duplicate defs and gather tests together
navidcy Jul 6, 2025
609e5a7
convert_output(mo::MultiRegionObject, model) always on CPU?
navidcy Jul 6, 2025
984eb19
MultiRegionOuputWriter fix
michel2323 Jul 7, 2025
3209c40
let go of arch_array
navidcy Jul 7, 2025
1c6fa73
reorganize imports
navidcy Jul 7, 2025
ff957c8
add method for architecture(::Type{<:AbstractArray})
navidcy Jul 7, 2025
93a75d7
Merge branch 'main' into ms/ka
navidcy Jul 7, 2025
3696c43
Merge branch 'main' into ms/ka
simone-silvestri Jul 14, 2025
7 changes: 5 additions & 2 deletions Project.toml
@@ -5,14 +5,14 @@ version = "0.96.33"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
CubedSphere = "7445602f-e544-4518-8976-18f8e8ae6cdb"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
@@ -40,6 +40,7 @@ TimesDates = "bdfc003b-8df8-5c39-adcd-3a9087f5df4a"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
@@ -51,6 +52,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
OceananigansAMDGPUExt = "AMDGPU"
OceananigansCUDAExt = "CUDA"
OceananigansEnzymeExt = "Enzyme"
OceananigansMakieExt = ["MakieCore", "Makie"]
OceananigansMetalExt = "Metal"
@@ -104,6 +106,7 @@ oneAPI = "2.0.1"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -116,4 +119,4 @@ TimesDates = "bdfc003b-8df8-5c39-adcd-3a9087f5df4a"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[targets]
test = ["AMDGPU", "oneAPI", "DataDeps", "SafeTestsets", "Test", "Enzyme", "Reactant", "Metal", "CUDA_Runtime_jll", "MPIPreferences", "TimesDates", "NCDatasets"]
test = ["AMDGPU", "CUDA", "oneAPI", "DataDeps", "SafeTestsets", "Test", "Enzyme", "Reactant", "Metal", "CUDA_Runtime_jll", "MPIPreferences", "TimesDates", "NCDatasets"]
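The essential change in this diff is that CUDA moves from `[deps]` to `[weakdeps]`, so `OceananigansCUDAExt` is loaded only when a user's environment also contains CUDA. A minimal sketch of the mechanism, using only names that appear in this diff:

```toml
# Project.toml (after this PR) — CUDA is no longer a hard dependency.
# It is a weak dependency: its UUID stays recorded here, but Pkg does not
# install it unless the user adds CUDA themselves.
[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

# When both Oceananigans and CUDA are loaded, Julia automatically loads
# the extension module defined in ext/OceananigansCUDAExt.jl.
[extensions]
OceananigansCUDAExt = "CUDA"
```

This is the standard Pkg package-extension mechanism (Julia 1.9+): CPU-only users no longer pay the CUDA install and load cost, while GPU users get identical behavior once they run `using CUDA`.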
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -1,4 +1,5 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
@@ -15,6 +16,7 @@ TimesDates = "bdfc003b-8df8-5c39-adcd-3a9087f5df4a"

[compat]
CairoMakie = "0.11, 0.12, 0.13"
CUDA = "5.4"
Documenter = "1"
DocumenterCitations = "1"
JLD2 = "0.4, 0,5"
1 change: 1 addition & 0 deletions docs/src/grids.md
@@ -49,6 +49,7 @@ As a result, the grid was constructed by default on the CPU.
Next we build a grid on the _GPU_ that's two-dimensional in ``x, z`` and has variably-spaced cell interfaces in the `z`-direction,

```jldoctest grids_gpu
using CUDA
architecture = GPU()
z_faces = [0, 1, 3, 6, 10]

1 change: 1 addition & 0 deletions docs/src/model_setup/legacy_grids.md
@@ -42,6 +42,7 @@ architecture. By default `architecture = CPU()`. By providing `GPU()` as the `ar
we can construct the grid on GPU:

```julia
julia> using CUDA
julia> grid = RectilinearGrid(GPU(), size = (32, 64, 256), extent = (128, 256, 512))
32×64×256 RectilinearGrid{Float64, Periodic, Periodic, Bounded} on GPU with 3×3×3 halo
├── Periodic x ∈ [0.0, 128.0) regularly spaced with Δx=4.0
1 change: 1 addition & 0 deletions docs/src/quick_start.md
@@ -65,6 +65,7 @@ CairoMakie.activate!(type = "png")
```@example gpu
using Oceananigans
using CairoMakie
using CUDA

grid = RectilinearGrid(GPU(),
size = (1024, 1024),
3 changes: 2 additions & 1 deletion examples/langmuir_turbulence.jl
@@ -18,11 +18,12 @@

# ```julia
# using Pkg
# pkg"add Oceananigans, CairoMakie"
# pkg"add Oceananigans, CairoMakie, CUDA"
# ```

using Oceananigans
using Oceananigans.Units: minute, minutes, hours
using CUDA

# ## Model set-up
#
108 changes: 90 additions & 18 deletions ext/OceananigansAMDGPUExt.jl
@@ -1,35 +1,101 @@
module OceananigansAMDGPUExt

using AMDGPU
using Oceananigans
using InteractiveUtils
using AMDGPU, AMDGPU.rocSPARSE, AMDGPU.rocFFT
using Oceananigans.Utils: linear_expand, __linear_ndrange, MappedCompilerMetadata
using KernelAbstractions: __dynamic_checkbounds, __iterspace
import KernelAbstractions: __validindex
using KernelAbstractions
import Oceananigans.Architectures as AC
import Oceananigans.BoundaryConditions as BC
import Oceananigans.DistributedComputations as DC
import Oceananigans.Fields as FD
import Oceananigans.Grids as GD
import Oceananigans.Solvers as SO
import Oceananigans.Utils as UT
import SparseArrays: SparseMatrixCSC
import KernelAbstractions: __iterspace, __groupindex, __dynamic_checkbounds,
__validindex, CompilerMetadata
import Oceananigans.DistributedComputations: Distributed

import Oceananigans.Architectures:
architecture,
convert_to_device,
on_architecture
const GPUVar = Union{ROCArray, Ptr}

const ROCGPU = GPU{<:AMDGPU.ROCBackend}
ROCGPU() = GPU(AMDGPU.ROCBackend())
function __init__()
if AMDGPU.functional()
@debug "ROCm-enabled GPU(s) detected:"
for (gpu, dev) in enumerate(AMDGPU.devices())
@debug "$dev: $(AMDGPU.name(dev))"
end
end
end

const ROCGPU = AC.GPU{ROCBackend}
ROCGPU() = AC.GPU(AMDGPU.ROCBackend())

architecture(::ROCArray) = ROCGPU()
Base.summary(::ROCGPU) = "ROCGPU"

on_architecture(::ROCGPU, a::Number) = a
on_architecture(::ROCGPU, a::Array) = ROCArray(a)
on_architecture(::ROCGPU, a::BitArray) = ROCArray(a)
on_architecture(::ROCGPU, a::SubArray{<:Any, <:Any, <:Array}) = ROCArray(a)
on_architecture(::CPU, a::ROCArray) = Array(a)
on_architecture(::CPU, a::SubArray{<:Any, <:Any, <:ROCArray}) = Array(a)
on_architecture(::ROCGPU, a::ROCArray) = a
on_architecture(::ROCGPU, a::SubArray{<:Any, <:Any, <:ROCArray}) = a
on_architecture(::ROCGPU, a::StepRangeLen) = a
AC.architecture(::ROCArray) = ROCGPU()
AC.architecture(::ROCSparseMatrixCSC) = ROCGPU()
AC.array_type(::AC.GPU{ROCBackend}) = ROCArray

AC.on_architecture(::ROCGPU, a::Number) = a
AC.on_architecture(::AC.CPU, a::ROCArray) = Array(a)
AC.on_architecture(::ROCGPU, a::Array) = ROCArray(a)
AC.on_architecture(::ROCGPU, a::ROCArray) = a
AC.on_architecture(::ROCGPU, a::BitArray) = ROCArray(a)
AC.on_architecture(::ROCGPU, a::SubArray{<:Any, <:Any, <:ROCArray}) = a
AC.on_architecture(::ROCGPU, a::SubArray{<:Any, <:Any, <:Array}) = ROCArray(a)
AC.on_architecture(::AC.CPU, a::SubArray{<:Any, <:Any, <:ROCArray}) = Array(a)
AC.on_architecture(::ROCGPU, a::StepRangeLen) = a
AC.on_architecture(arch::Distributed, a::ROCArray) = AC.on_architecture(AC.child_architecture(arch), a)
AC.on_architecture(arch::Distributed, a::SubArray{<:Any, <:Any, <:ROCArray}) = AC.on_architecture(child_architecture(arch), a)

function AC.unified_array(::ROCGPU, a::AbstractArray)
error("unified_array is not implemented for ROCGPU.")
end

## GPU to GPU copy of contiguous data
@inline function AC.device_copy_to!(dst::ROCArray, src::ROCArray; async::Bool = false)
if async == true
@warn "Asynchronous copy is not supported for ROCArray. Falling back to synchronous copy."
end
copyto!(dst, src)
return dst
end

@inline AC.unsafe_free!(a::ROCArray) = AMDGPU.unsafe_free!(a)

@inline AC.constructors(::AC.GPU{ROCBackend}, A::SparseMatrixCSC) = (ROCArray(A.colptr), ROCArray(A.rowval), ROCArray(A.nzval), (A.m, A.n))
@inline AC.constructors(::AC.CPU, A::ROCSparseMatrixCSC) = (A.dims[1], A.dims[2], Int64.(Array(A.colPtr)), Int64.(Array(A.rowVal)), Array(A.nzVal))
@inline AC.constructors(::AC.GPU{ROCBackend}, A::ROCSparseMatrixCSC) = (A.colPtr, A.rowVal, A.nzVal, A.dims)

@inline AC.arch_sparse_matrix(::AC.GPU{ROCBackend}, constr::Tuple) = ROCSparseMatrixCSC(constr...)
@inline AC.arch_sparse_matrix(::AC.CPU, A::ROCSparseMatrixCSC) = SparseMatrixCSC(AC.constructors(AC.CPU(), A)...)
@inline AC.arch_sparse_matrix(::AC.GPU{ROCBackend}, A::SparseMatrixCSC) = ROCSparseMatrixCSC(AC.constructors(AC.GPU(), A)...)
@inline AC.arch_sparse_matrix(::AC.GPU{ROCBackend}, A::ROCSparseMatrixCSC) = A

@inline convert_to_device(::ROCGPU, args) = AMDGPU.rocconvert(args)
@inline convert_to_device(::ROCGPU, args::Tuple) = map(AMDGPU.rocconvert, args)


BC.validate_boundary_condition_architecture(::ROCArray, ::AC.GPU, bc, side) = nothing

BC.validate_boundary_condition_architecture(::ROCArray, ::AC.CPU, bc, side) =
throw(ArgumentError("$side $bc must use `Array` rather than `ROCArray` on CPU architectures!"))

function SO.plan_forward_transform(A::ROCArray, ::Union{GD.Bounded, GD.Periodic}, dims, planner_flag)
length(dims) == 0 && return nothing
return AMDGPU.rocFFT.plan_fft!(A, dims)
end

FD.set!(v::Field, a::ROCArray) = FD._set!(v, a)
DC.set!(v::DC.DistributedField, a::ROCArray) = DC._set!(v, a)

function SO.plan_backward_transform(A::ROCArray, ::Union{GD.Bounded, GD.Periodic}, dims, planner_flag)
length(dims) == 0 && return nothing
return AMDGPU.rocFFT.plan_ifft!(A, dims)
end

AMDGPU.Device.@device_override @inline function __validindex(ctx::MappedCompilerMetadata)
if __dynamic_checkbounds(ctx)
I = @inbounds linear_expand(__iterspace(ctx), AMDGPU.Device.blockIdx().x, AMDGPU.Device.threadIdx().x)
@@ -39,4 +105,10 @@ AMDGPU.Device.@device_override @inline function __validindex(ctx::MappedCompiler
end
end

@inline UT.getdevice(roc::GPUVar, i) = device(roc)
@inline UT.getdevice(roc::GPUVar) = device(roc)
@inline UT.switch_device!(dev::Int64) = device!(dev)
@inline UT.sync_device!(::ROCGPU) = AMDGPU.synchronize()
@inline UT.sync_device!(::ROCBackend) = AMDGPU.synchronize()

end # module
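With the extension populated, architecture dispatch for AMD hardware works through the `AC.*` methods defined above. A hedged usage sketch (requires a ROCm-capable machine, so it is illustrative rather than testable here):

```julia
using Oceananigans, AMDGPU    # loading AMDGPU activates OceananigansAMDGPUExt

arch = GPU(AMDGPU.ROCBackend())              # the ROCGPU alias defined in this extension
a = on_architecture(arch, rand(Float64, 4))  # Array -> ROCArray, via AC.on_architecture
b = on_architecture(CPU(), a)                # ROCArray -> Array round trip
architecture(a)                              # ROCGPU(), via AC.architecture(::ROCArray)
```

Note that `AC.unified_array` deliberately errors for ROCGPU, and `device_copy_to!` falls back to a synchronous `copyto!` with a warning when `async = true` — two places where the AMD path is intentionally narrower than the CUDA one.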
138 changes: 138 additions & 0 deletions ext/OceananigansCUDAExt.jl
@@ -0,0 +1,138 @@
module OceananigansCUDAExt

using Oceananigans
using InteractiveUtils
using CUDA, CUDA.CUSPARSE, CUDA.CUFFT
using Oceananigans.Utils: linear_expand, __linear_ndrange, MappedCompilerMetadata
using KernelAbstractions: __dynamic_checkbounds, __iterspace
using KernelAbstractions
import Oceananigans.Architectures as AC
import Oceananigans.BoundaryConditions as BC
import Oceananigans.DistributedComputations as DC
import Oceananigans.Fields as FD
import Oceananigans.Grids as GD
import Oceananigans.Solvers as SO
import Oceananigans.Utils as UT
import SparseArrays: SparseMatrixCSC
import KernelAbstractions: __iterspace, __groupindex, __dynamic_checkbounds,
__validindex, CompilerMetadata
import Oceananigans.DistributedComputations: Distributed

const GPUVar = Union{CuArray, CuContext, CuPtr, Ptr}

function __init__()
if CUDA.functional()
@debug "CUDA-enabled GPU(s) detected:"
for (gpu, dev) in enumerate(CUDA.devices())
@debug "$dev: $(CUDA.name(dev))"
end

CUDA.allowscalar(false)
end
end

const CUDAGPU = AC.GPU{<:CUDABackend}
CUDAGPU() = AC.GPU(CUDABackend(always_inline=true))

# Keep default CUDA backend
function AC.GPU()
if CUDA.has_cuda_gpu()
return CUDAGPU()
else
msg = """We cannot make a GPU with the CUDA backend:
a CUDA GPU was not found!"""
throw(ArgumentError(msg))
end
end

function UT.versioninfo_with_gpu(::CUDAGPU)
s = sprint(versioninfo)
gpu_name = CUDA.CuDevice(0) |> CUDA.name
return "CUDA GPU: $gpu_name"
end


Base.summary(::CUDAGPU) = "CUDAGPU"

AC.architecture(::CuArray) = CUDAGPU()
AC.architecture(::CuSparseMatrixCSC) = CUDAGPU()
AC.array_type(::AC.GPU{CUDABackend}) = CuArray

AC.on_architecture(::CUDAGPU, a::Number) = a
AC.on_architecture(::AC.CPU, a::CuArray) = Array(a)
AC.on_architecture(::CUDAGPU, a::Array) = CuArray(a)
AC.on_architecture(::CUDAGPU, a::CuArray) = a
AC.on_architecture(::CUDAGPU, a::BitArray) = CuArray(a)
AC.on_architecture(::CUDAGPU, a::SubArray{<:Any, <:Any, <:CuArray}) = a
AC.on_architecture(::CUDAGPU, a::SubArray{<:Any, <:Any, <:Array}) = CuArray(a)
AC.on_architecture(::AC.CPU, a::SubArray{<:Any, <:Any, <:CuArray}) = Array(a)
AC.on_architecture(::CUDAGPU, a::StepRangeLen) = a
AC.on_architecture(arch::Distributed, a::CuArray) = AC.on_architecture(AC.child_architecture(arch), a)
AC.on_architecture(arch::Distributed, a::SubArray{<:Any, <:Any, <:CuArray}) = AC.on_architecture(child_architecture(arch), a)

# cu alters the type of `a`, so we convert it back to the correct type
AC.unified_array(::CUDAGPU, a::AbstractArray) = map(eltype(a), cu(a; unified = true))

## GPU to GPU copy of contiguous data
@inline function AC.device_copy_to!(dst::CuArray, src::CuArray; async::Bool = false)
n = length(src)
context!(context(src)) do
GC.@preserve src dst begin
unsafe_copyto!(pointer(dst, 1), pointer(src, 1), n; async)
end
end
return dst
end

@inline AC.unsafe_free!(a::CuArray) = CUDA.unsafe_free!(a)

@inline AC.constructors(::AC.GPU{CUDABackend}, A::SparseMatrixCSC) = (CuArray(A.colptr), CuArray(A.rowval), CuArray(A.nzval), (A.m, A.n))
@inline AC.constructors(::AC.CPU, A::CuSparseMatrixCSC) = (A.dims[1], A.dims[2], Int64.(Array(A.colPtr)), Int64.(Array(A.rowVal)), Array(A.nzVal))
@inline AC.constructors(::AC.GPU{CUDABackend}, A::CuSparseMatrixCSC) = (A.colPtr, A.rowVal, A.nzVal, A.dims)

@inline AC.arch_sparse_matrix(::AC.GPU{CUDABackend}, constr::Tuple) = CuSparseMatrixCSC(constr...)
@inline AC.arch_sparse_matrix(::AC.CPU, A::CuSparseMatrixCSC) = SparseMatrixCSC(AC.constructors(AC.CPU(), A)...)
@inline AC.arch_sparse_matrix(::AC.GPU{CUDABackend}, A::SparseMatrixCSC) = CuSparseMatrixCSC(AC.constructors(AC.GPU(), A)...)
@inline AC.arch_sparse_matrix(::AC.GPU{CUDABackend}, A::CuSparseMatrixCSC) = A

@inline AC.convert_to_device(::CUDAGPU, args) = CUDA.cudaconvert(args)
@inline AC.convert_to_device(::CUDAGPU, args::Tuple) = map(CUDA.cudaconvert, args)


BC.validate_boundary_condition_architecture(::CuArray, ::AC.GPU, bc, side) = nothing

BC.validate_boundary_condition_architecture(::CuArray, ::AC.CPU, bc, side) =
throw(ArgumentError("$side $bc must use `Array` rather than `CuArray` on CPU architectures!"))

function SO.plan_forward_transform(A::CuArray, ::Union{GD.Bounded, GD.Periodic}, dims, planner_flag)
length(dims) == 0 && return nothing
return CUDA.CUFFT.plan_fft!(A, dims)
end

FD.set!(v::Field, a::CuArray) = FD._set!(v, a)
DC.set!(v::DC.DistributedField, a::CuArray) = DC._set!(v, a)

function SO.plan_backward_transform(A::CuArray, ::Union{GD.Bounded, GD.Periodic}, dims, planner_flag)
length(dims) == 0 && return nothing
return CUDA.CUFFT.plan_ifft!(A, dims)
end

# CUDA version, the indices are passed implicitly
# You must not use KA here as this code is executed in another scope
CUDA.@device_override @inline function __validindex(ctx::MappedCompilerMetadata)
if __dynamic_checkbounds(ctx)
index = @inbounds linear_expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
return index ≤ __linear_ndrange(ctx)
else
return true
end
end

@inline UT.sync_device!(::CuDevice) = CUDA.synchronize()
@inline UT.getdevice(cu::GPUVar, i) = device(cu)
@inline UT.getdevice(cu::GPUVar) = device(cu)
@inline UT.switch_device!(dev::CuDevice) = device!(dev)
@inline UT.sync_device!(::CUDAGPU) = CUDA.synchronize()
@inline UT.sync_device!(::CUDABackend) = CUDA.synchronize()

end # module OceananigansCUDAExt
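Because `AC.GPU()` is now defined inside this extension, the zero-argument `GPU()` constructor only resolves to a CUDA backend after the user loads CUDA — the user-visible change behind most of the `using CUDA` additions in the docs diffs above. A sketch of the intended workflow (assumes a CUDA-capable machine):

```julia
using Oceananigans
using CUDA     # activates OceananigansCUDAExt; without it, GPU() has no CUDA default

arch = GPU()   # resolved by AC.GPU() above; throws ArgumentError if no CUDA GPU is found
grid = RectilinearGrid(arch, size = (16, 16, 16), extent = (1, 1, 1))
```

The `__init__` hook also calls `CUDA.allowscalar(false)` on functional systems, so scalar indexing of `CuArray`s errors by default, matching the previous in-tree behavior.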
1 change: 0 additions & 1 deletion src/AbstractOperations/AbstractOperations.jl
@@ -6,7 +6,6 @@ export Average, Integral, CumulativeIntegral, KernelFunctionOperation
export UnaryOperation, Derivative, BinaryOperation, MultiaryOperation, ConditionalOperation


using CUDA
using Base: @propagate_inbounds

using Oceananigans.Architectures
2 changes: 1 addition & 1 deletion src/AbstractOperations/binary_operations.jl
@@ -210,7 +210,7 @@ end
##### GPU capabilities
#####

"Adapt `BinaryOperation` to work on the GPU via CUDAnative and CUDAdrv."
"Adapt `BinaryOperation` to work on the GPU via KernelAbstractions."
Adapt.adapt_structure(to, binary::BinaryOperation{LX, LY, LZ}) where {LX, LY, LZ} =
BinaryOperation{LX, LY, LZ}(Adapt.adapt(to, binary.op),
Adapt.adapt(to, binary.a),
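For context on the docstring change: `Adapt.adapt_structure` recursively converts a wrapper's fields so the whole object can be shipped into a KernelAbstractions kernel, regardless of backend. A generic, CPU-runnable sketch of the pattern — not Oceananigans code, just an illustration of the rule the `BinaryOperation` method follows:

```julia
using Adapt

struct Scaled{A}
    coeff::Float64
    data::A
end

# Rebuild the wrapper, adapting only the field that may live on a device,
# mirroring how BinaryOperation adapts op, a, b, and grid in this diff.
Adapt.adapt_structure(to, s::Scaled) = Scaled(s.coeff, Adapt.adapt(to, s.data))

s = Scaled(2.0, [1.0, 2.0, 3.0])
adapt(Array{Float32}, s)   # on a GPU, `to` would be e.g. CuArray instead
```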
2 changes: 1 addition & 1 deletion src/AbstractOperations/kernel_function_operation.jl
@@ -69,7 +69,7 @@ end
indices(κ::KernelFunctionOperation) = construct_regionally(intersect_indices, location(κ), κ.arguments...)
compute_at!(κ::KernelFunctionOperation, time) = Tuple(compute_at!(d, time) for d in κ.arguments)

"Adapt `KernelFunctionOperation` to work on the GPU via CUDAnative and CUDAdrv."
"Adapt `KernelFunctionOperation` to work on the GPU via KernelAbstractions."
Adapt.adapt_structure(to, κ::KernelFunctionOperation{LX, LY, LZ}) where {LX, LY, LZ} =
KernelFunctionOperation{LX, LY, LZ}(Adapt.adapt(to, κ.kernel_function),
Adapt.adapt(to, κ.grid),