Skip to content

Commit b7a6883

Browse files
committed
Using NVVM intrinsics for the syncthreads and threadfence families of calls
1 parent 2546ae4 commit b7a6883

File tree

4 files changed

+103
-11
lines changed

4 files changed

+103
-11
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,10 @@ struct IntrinsicLibrary {
392392
fir::ExtendedValue genSum(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
393393
void genSignalSubroutine(llvm::ArrayRef<fir::ExtendedValue>);
394394
void genSleep(llvm::ArrayRef<fir::ExtendedValue>);
395+
void genSyncThreads(llvm::ArrayRef<fir::ExtendedValue>);
396+
mlir::Value genSyncThreadsAnd(mlir::Type,llvm::ArrayRef<mlir::Value>);
397+
mlir::Value genSyncThreadsCount(mlir::Type,llvm::ArrayRef<mlir::Value>);
398+
mlir::Value genSyncThreadsOr(mlir::Type,llvm::ArrayRef<mlir::Value>);
395399
fir::ExtendedValue genSystem(std::optional<mlir::Type>,
396400
mlir::ArrayRef<fir::ExtendedValue> args);
397401
void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);
@@ -401,6 +405,9 @@ struct IntrinsicLibrary {
401405
llvm::ArrayRef<fir::ExtendedValue>);
402406
fir::ExtendedValue genTranspose(mlir::Type,
403407
llvm::ArrayRef<fir::ExtendedValue>);
408+
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
409+
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
410+
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
404411
fir::ExtendedValue genTrim(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
405412
fir::ExtendedValue genUbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
406413
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,10 @@ static constexpr IntrinsicHandler handlers[]{
642642
{"dim", asValue},
643643
{"mask", asBox, handleDynamicOptional}}},
644644
/*isElemental=*/false},
645+
{"syncthreads", &I::genSyncThreads, {}, /*isElemental=*/false},
646+
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
647+
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
648+
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
645649
{"system",
646650
&I::genSystem,
647651
{{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -660,6 +664,9 @@ static constexpr IntrinsicHandler handlers[]{
660664
&I::genTranspose,
661665
{{{"matrix", asAddr}}},
662666
/*isElemental=*/false},
667+
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
668+
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
669+
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
663670
{"trim", &I::genTrim, {{{"string", asAddr}}}, /*isElemental=*/false},
664671
{"ubound",
665672
&I::genUbound,
@@ -7290,6 +7297,52 @@ IntrinsicLibrary::genSum(mlir::Type resultType,
72907297
resultType, args);
72917298
}
72927299

7300+
// SYNCTHREADS
7301+
void IntrinsicLibrary::genSyncThreads(llvm::ArrayRef<fir::ExtendedValue> args) {
7302+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0";
7303+
mlir::FunctionType funcType =
7304+
mlir::FunctionType::get(builder.getContext(), {}, {});
7305+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7306+
llvm::SmallVector<mlir::Value> noArgs;
7307+
builder.create<fir::CallOp>(loc, funcOp, noArgs);
7308+
}
7309+
7310+
// SYNCTHREADS_AND
7311+
mlir::Value
7312+
IntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
7313+
llvm::ArrayRef<mlir::Value> args) {
7314+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and";
7315+
mlir::MLIRContext *context = builder.getContext();
7316+
mlir::FunctionType ftype =
7317+
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
7318+
auto funcOp = builder.createFunction(loc, funcName, ftype);
7319+
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
7320+
}
7321+
7322+
// SYNCTHREADS_COUNT
7323+
mlir::Value
7324+
IntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
7325+
llvm::ArrayRef<mlir::Value> args) {
7326+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc";
7327+
mlir::MLIRContext *context = builder.getContext();
7328+
mlir::FunctionType ftype =
7329+
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
7330+
auto funcOp = builder.createFunction(loc, funcName, ftype);
7331+
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
7332+
}
7333+
7334+
// SYNCTHREADS_OR
7335+
mlir::Value
7336+
IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
7337+
llvm::ArrayRef<mlir::Value> args) {
7338+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or";
7339+
mlir::MLIRContext *context = builder.getContext();
7340+
mlir::FunctionType ftype =
7341+
mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
7342+
auto funcOp = builder.createFunction(loc, funcName, ftype);
7343+
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
7344+
}
7345+
72937346
// SYSTEM
72947347
fir::ExtendedValue
72957348
IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
@@ -7420,6 +7473,38 @@ IntrinsicLibrary::genTranspose(mlir::Type resultType,
74207473
return readAndAddCleanUp(resultMutableBox, resultType, "TRANSPOSE");
74217474
}
74227475

7476+
// THREADFENCE
7477+
void IntrinsicLibrary::genThreadFence(llvm::ArrayRef<fir::ExtendedValue> args) {
7478+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.gl";
7479+
mlir::FunctionType funcType =
7480+
mlir::FunctionType::get(builder.getContext(), {}, {});
7481+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7482+
llvm::SmallVector<mlir::Value> noArgs;
7483+
builder.create<fir::CallOp>(loc, funcOp, noArgs);
7484+
}
7485+
7486+
// THREADFENCE_BLOCK
7487+
void IntrinsicLibrary::genThreadFenceBlock(
7488+
llvm::ArrayRef<fir::ExtendedValue> args) {
7489+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.cta";
7490+
mlir::FunctionType funcType =
7491+
mlir::FunctionType::get(builder.getContext(), {}, {});
7492+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7493+
llvm::SmallVector<mlir::Value> noArgs;
7494+
builder.create<fir::CallOp>(loc, funcOp, noArgs);
7495+
}
7496+
7497+
// THREADFENCE_SYSTEM
7498+
void IntrinsicLibrary::genThreadFenceSystem(
7499+
llvm::ArrayRef<fir::ExtendedValue> args) {
7500+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.sys";
7501+
mlir::FunctionType funcType =
7502+
mlir::FunctionType::get(builder.getContext(), {}, {});
7503+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7504+
llvm::SmallVector<mlir::Value> noArgs;
7505+
builder.create<fir::CallOp>(loc, funcOp, noArgs);
7506+
}
7507+
74237508
// TRIM
74247509
fir::ExtendedValue
74257510
IntrinsicLibrary::genTrim(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,27 @@ module cudadevice
1818
! Synchronization Functions
1919

2020
interface
21-
attributes(device) subroutine syncthreads() bind(c, name='__syncthreads')
21+
attributes(device) subroutine syncthreads()
2222
end subroutine
2323
end interface
2424
public :: syncthreads
2525

2626
interface
27-
attributes(device) integer function syncthreads_and(value) bind(c, name='__syncthreads_and')
27+
attributes(device) integer function syncthreads_and(value)
2828
integer :: value
2929
end function
3030
end interface
3131
public :: syncthreads_and
3232

3333
interface
34-
attributes(device) integer function syncthreads_count(value) bind(c, name='__syncthreads_count')
34+
attributes(device) integer function syncthreads_count(value)
3535
integer :: value
3636
end function
3737
end interface
3838
public :: syncthreads_count
3939

4040
interface
41-
attributes(device) integer function syncthreads_or(value) bind(c, name='__syncthreads_or')
41+
attributes(device) integer function syncthreads_or(value)
4242
integer :: value
4343
end function
4444
end interface
@@ -54,19 +54,19 @@ attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
5454
! Memory Fences
5555

5656
interface
57-
attributes(device) subroutine threadfence() bind(c, name='__threadfence')
57+
attributes(device) subroutine threadfence()
5858
end subroutine
5959
end interface
6060
public :: threadfence
6161

6262
interface
63-
attributes(device) subroutine threadfence_block() bind(c, name='__threadfence_block')
63+
attributes(device) subroutine threadfence_block()
6464
end subroutine
6565
end interface
6666
public :: threadfence_block
6767

6868
interface
69-
attributes(device) subroutine threadfence_system() bind(c, name='__threadfence_system')
69+
attributes(device) subroutine threadfence_system()
7070
end subroutine
7171
end interface
7272
public :: threadfence_system

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ attributes(global) subroutine devsub()
1717
end
1818

1919
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
20-
! CHECK: fir.call @__syncthreads()
20+
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
2121
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
2222
! threadfence, threadfence_block and threadfence_system are lowered to calls
! of the llvm.nvvm.membar.* functions rather than the bind(c) __threadfence*
! procedures.
! CHECK: fir.call @llvm.nvvm.membar.gl()
! CHECK: fir.call @llvm.nvvm.membar.cta()
! CHECK: fir.call @llvm.nvvm.membar.sys()
25-
! CHECK: %{{.*}} = fir.call @__syncthreads_and(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
26-
! CHECK: %{{.*}} = fir.call @__syncthreads_count(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
27-
! CHECK: %{{.*}} = fir.call @__syncthreads_or(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
25+
! Operands are matched with a pattern so the checks do not depend on the
! exact SSA value names chosen by the compiler.
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%{{.*}}) fastmath<contract> : (i32) -> i32
2828

2929
! syncthreads now calls llvm.nvvm.barrier0, so that is the private declaration
! expected here (attributes intentionally left unmatched).
! CHECK: func.func private @llvm.nvvm.barrier0()
3030
! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}

0 commit comments

Comments
 (0)