Skip to content

Commit 6c33c24

Browse files
committed
[MLIR][OpenMP]Add prescriptiveness-modifier support to grainsize and num_tasks clause.
1 parent a3093e5 commit 6c33c24

File tree

4 files changed

+159
-17
lines changed

4 files changed

+159
-17
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,11 @@ class OpenMP_GrainsizeClauseSkip<
436436
bit description = false, bit extraClassDeclaration = false
437437
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
438438
extraClassDeclaration> {
439-
let arguments = (ins
440-
Optional<IntLikeType>:$grainsize
441-
);
439+
let arguments = (ins OptionalAttr<GrainsizeTypeAttr>:$grainsize_mod,
440+
Optional<IntLikeType>:$grainsize);
442441

443442
let optAssemblyFormat = [{
444-
`grainsize` `(` $grainsize `:` type($grainsize) `)`
443+
`grainsize` `(` custom<GrainsizeClause>($grainsize_mod , $grainsize, type($grainsize)) `)`
445444
}];
446445

447446
let description = [{
@@ -895,12 +894,11 @@ class OpenMP_NumTasksClauseSkip<
895894
bit description = false, bit extraClassDeclaration = false
896895
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
897896
extraClassDeclaration> {
898-
let arguments = (ins
899-
Optional<IntLikeType>:$num_tasks
900-
);
897+
let arguments = (ins OptionalAttr<NumTasksTypeAttr>:$num_tasks_mod,
898+
Optional<IntLikeType>:$num_tasks);
901899

902900
let optAssemblyFormat = [{
903-
`num_tasks` `(` $num_tasks `:` type($num_tasks) `)`
901+
`num_tasks` `(` custom<NumTasksClause>($num_tasks_mod , $num_tasks, type($num_tasks)) `)`
904902
}];
905903

906904
let description = [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 113 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,108 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472472
p << stringifyClauseOrderKind(order.getValue());
473473
}
474474

475+
//===----------------------------------------------------------------------===//
476+
// Parser and printer for grainsize Clause
477+
//===----------------------------------------------------------------------===//
478+
479+
// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
480+
static ParseResult
481+
parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
482+
std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
483+
Type &grainsizeType) {
484+
SMLoc loc = parser.getCurrentLocation();
485+
StringRef enumStr;
486+
487+
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
488+
if (std::optional<ClauseGrainsizeType> enumValue =
489+
symbolizeClauseGrainsizeType(enumStr)) {
490+
grainsizeMod =
491+
ClauseGrainsizeTypeAttr::get(parser.getContext(), *enumValue);
492+
if (parser.parseColon())
493+
return failure();
494+
} else {
495+
return parser.emitError(loc, "invalid grainsize modifier : '")
496+
<< enumStr << "'";
497+
}
498+
}
499+
500+
OpAsmParser::UnresolvedOperand operand;
501+
if (succeeded(parser.parseOperand(operand))) {
502+
grainsize = operand;
503+
} else {
504+
return parser.emitError(parser.getCurrentLocation())
505+
<< "expected grainsize operand";
506+
}
507+
508+
if (grainsize.has_value()) {
509+
if (parser.parseColonType(grainsizeType))
510+
return failure();
511+
}
512+
513+
return success();
514+
}
515+
516+
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
517+
ClauseGrainsizeTypeAttr grainsizeMod,
518+
Value grainsize, mlir::Type grainsizeType) {
519+
if (grainsizeMod)
520+
p << stringifyClauseGrainsizeType(grainsizeMod.getValue()) << ": ";
521+
522+
if (grainsize)
523+
p << grainsize << ": " << grainsizeType;
524+
}
525+
526+
//===----------------------------------------------------------------------===//
527+
// Parser and printer for num_tasks Clause
528+
//===----------------------------------------------------------------------===//
529+
530+
// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
531+
static ParseResult
532+
parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
533+
std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
534+
Type &numTasksType) {
535+
SMLoc loc = parser.getCurrentLocation();
536+
StringRef enumStr;
537+
538+
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
539+
if (std::optional<ClauseNumTasksType> enumValue =
540+
symbolizeClauseNumTasksType(enumStr)) {
541+
numTasksMod =
542+
ClauseNumTasksTypeAttr::get(parser.getContext(), *enumValue);
543+
if (parser.parseColon())
544+
return failure();
545+
} else {
546+
return parser.emitError(loc, "invalid numTasks modifier : '")
547+
<< enumStr << "'";
548+
}
549+
}
550+
551+
OpAsmParser::UnresolvedOperand operand;
552+
if (succeeded(parser.parseOperand(operand))) {
553+
numTasks = operand;
554+
} else {
555+
return parser.emitError(parser.getCurrentLocation())
556+
<< "expected num_tasks operand";
557+
}
558+
559+
if (numTasks.has_value()) {
560+
if (parser.parseColonType(numTasksType))
561+
return failure();
562+
}
563+
564+
return success();
565+
}
566+
567+
static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
568+
ClauseNumTasksTypeAttr numTasksMod,
569+
Value numTasks, mlir::Type numTasksType) {
570+
if (numTasksMod)
571+
p << stringifyClauseNumTasksType(numTasksMod.getValue()) << ": ";
572+
573+
if (numTasks)
574+
p << numTasks << ": " << numTasksType;
575+
}
576+
475577
//===----------------------------------------------------------------------===//
476578
// Parsers for operations including clauses that define entry block arguments.
477579
//===----------------------------------------------------------------------===//
@@ -2593,15 +2695,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
25932695
const TaskloopOperands &clauses) {
25942696
MLIRContext *ctx = builder.getContext();
25952697
// TODO Store clauses in op: privateVars, privateSyms.
2596-
TaskloopOp::build(
2597-
builder, state, clauses.allocateVars, clauses.allocatorVars,
2598-
clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
2599-
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2600-
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2601-
clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
2602-
/*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
2603-
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2604-
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2698+
TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2699+
clauses.final, clauses.grainsizeMod, clauses.grainsize,
2700+
clauses.ifExpr, clauses.inReductionVars,
2701+
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2702+
makeArrayAttr(ctx, clauses.inReductionSyms),
2703+
clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
2704+
clauses.numTasks, clauses.priority, /*private_vars=*/{},
2705+
/*private_syms=*/nullptr, clauses.reductionMod,
2706+
clauses.reductionVars,
2707+
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2708+
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
26052709
}
26062710

26072711
SmallVector<Value> TaskloopOp::getAllReductionVars() {

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2064,6 +2064,30 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
20642064

20652065
// -----
20662066

2067+
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
2068+
%testi64 = "test.i64"() : () -> (i64)
2069+
// expected-error @below {{invalid grainsize modifier : 'strict1'}}
2070+
omp.taskloop grainsize(strict1: %testi64: i64) {
2071+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
2072+
omp.yield
2073+
}
2074+
}
2075+
return
2076+
}
2077+
// -----
2078+
2079+
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
2080+
%testi64 = "test.i64"() : () -> (i64)
2081+
// expected-error @below {{invalid numTasks modifier : 'default'}}
2082+
omp.taskloop num_tasks(default: %testi64: i64) {
2083+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
2084+
omp.yield
2085+
}
2086+
}
2087+
return
2088+
}
2089+
// -----
2090+
20672091
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
20682092
// expected-error @below {{op nested in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
20692093
omp.taskloop {

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2417,6 +2417,22 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
24172417
}
24182418
}
24192419

2420+
// CHECK: omp.taskloop grainsize(strict: %{{[^:]+}}: i64) {
2421+
omp.taskloop grainsize(strict: %testi64: i64) {
2422+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
2423+
// CHECK: omp.yield
2424+
omp.yield
2425+
}
2426+
}
2427+
2428+
// CHECK: omp.taskloop num_tasks(strict: %{{[^:]+}}: i64) {
2429+
omp.taskloop num_tasks(strict: %testi64: i64) {
2430+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
2431+
// CHECK: omp.yield
2432+
omp.yield
2433+
}
2434+
}
2435+
24202436
// CHECK: omp.taskloop nogroup {
24212437
omp.taskloop nogroup {
24222438
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {

0 commit comments

Comments
 (0)