diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index a8d97a36df79e..32c28f72ec8e5 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -436,12 +436,11 @@ class OpenMP_GrainsizeClauseSkip< bit description = false, bit extraClassDeclaration = false > : OpenMP_Clause { - let arguments = (ins - Optional:$grainsize - ); + let arguments = (ins OptionalAttr:$grainsize_mod, + Optional:$grainsize); let optAssemblyFormat = [{ - `grainsize` `(` $grainsize `:` type($grainsize) `)` + `grainsize` `(` custom($grainsize_mod , $grainsize, type($grainsize)) `)` }]; let description = [{ @@ -895,12 +894,11 @@ class OpenMP_NumTasksClauseSkip< bit description = false, bit extraClassDeclaration = false > : OpenMP_Clause { - let arguments = (ins - Optional:$num_tasks - ); + let arguments = (ins OptionalAttr:$num_tasks_mod, + Optional:$num_tasks); let optAssemblyFormat = [{ - `num_tasks` `(` $num_tasks `:` type($num_tasks) `)` + `num_tasks` `(` custom($num_tasks_mod , $num_tasks, type($num_tasks)) `)` }]; let description = [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index d725a457aeff6..23ddc2fd53347 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -472,6 +472,99 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op, p << stringifyClauseOrderKind(order.getValue()); } +template +static ParseResult +parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, + std::optional &operand, + Type &operandType, + std::optional (*symbolizeClause)(StringRef), + StringRef clauseName) { + StringRef enumStr; + if (succeeded(parser.parseOptionalKeyword(&enumStr))) { + if (std::optional enumValue = symbolizeClause(enumStr)) { + prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue); + if (parser.parseComma()) + return failure(); + } else { + return parser.emitError(parser.getCurrentLocation()) + << "invalid " << clauseName << " modifier : '" << enumStr << "'"; + ; + } + } + + OpAsmParser::UnresolvedOperand var; + if (succeeded(parser.parseOperand(var))) { + operand = var; + } else { + return parser.emitError(parser.getCurrentLocation()) + << "expected " << clauseName << " operand"; + } + + if (operand.has_value()) { + if (parser.parseColonType(operandType)) + return failure(); + } + + return success(); +} + +template +static void +printGranularityClause(OpAsmPrinter &p, Operation *op, + ClauseTypeAttr prescriptiveness, Value operand, + mlir::Type operandType, + StringRef (*stringifyClauseType)(ClauseType)) { + + if (prescriptiveness) + p << stringifyClauseType(prescriptiveness.getValue()) << ", "; + + if (operand) + p << operand << ": " << operandType; +} + +//===----------------------------------------------------------------------===// +// Parser and printer for grainsize Clause +//===----------------------------------------------------------------------===// + +// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)` +static ParseResult +parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, + std::optional &grainsize, + Type &grainsizeType) { + return parseGranularityClause( + parser, grainsizeMod, grainsize, grainsizeType, + &symbolizeClauseGrainsizeType, "grainsize"); +} + +static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, + ClauseGrainsizeTypeAttr grainsizeMod, + Value grainsize, mlir::Type grainsizeType) { + printGranularityClause( + p, op, grainsizeMod, grainsize, grainsizeType, + &stringifyClauseGrainsizeType); +} + +//===----------------------------------------------------------------------===// +// Parser and printer for num_tasks Clause +//===----------------------------------------------------------------------===// + +// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)` +static ParseResult +parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, + std::optional &numTasks, + Type &numTasksType) { + return parseGranularityClause( + parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType, + "num_tasks"); +} + +static void printNumTasksClause(OpAsmPrinter &p, Operation *op, + ClauseNumTasksTypeAttr numTasksMod, + Value numTasks, mlir::Type numTasksType) { + printGranularityClause( + p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType); +} + //===----------------------------------------------------------------------===// // Parsers for operations including clauses that define entry block arguments. //===----------------------------------------------------------------------===// @@ -2593,15 +2686,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state, const TaskloopOperands &clauses) { MLIRContext *ctx = builder.getContext(); // TODO Store clauses in op: privateVars, privateSyms. - TaskloopOp::build( - builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars, - makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), - makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable, - clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{}, - /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied); + TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.final, clauses.grainsizeMod, clauses.grainsize, + clauses.ifExpr, clauses.inReductionVars, + makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), + makeArrayAttr(ctx, clauses.inReductionSyms), + clauses.mergeable, clauses.nogroup, clauses.numTasksMod, + clauses.numTasks, clauses.priority, /*private_vars=*/{}, + /*private_syms=*/nullptr, clauses.reductionMod, + clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied); } SmallVector TaskloopOp::getAllReductionVars() { diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index d7f468bed3d3d..eeb124ebd5eb8 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2064,6 +2064,30 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { // ----- +func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { + %testi64 = "test.i64"() : () -> (i64) + // expected-error @below {{invalid grainsize modifier : 'strict1'}} + omp.taskloop grainsize(strict1, %testi64: i64) { + omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { + omp.yield + } + } + return +} +// ----- + +func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { + %testi64 = "test.i64"() : () -> (i64) + // expected-error @below {{invalid num_tasks modifier : 'default'}} + omp.taskloop num_tasks(default, %testi64: i64) { + omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { + omp.yield + } + } + return +} +// ----- + func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { // expected-error @below {{op nested in loop wrapper is not another loop wrapper or `omp.loop_nest`}} omp.taskloop { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index e318afbebbf0c..72bb1db72377b 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -2417,6 +2417,22 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () { } } + // CHECK: omp.taskloop grainsize(strict, %{{[^:]+}}: i64) { + omp.taskloop grainsize(strict, %testi64: i64) { + omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { + // CHECK: omp.yield + omp.yield + } + } + + // CHECK: omp.taskloop num_tasks(strict, %{{[^:]+}}: i64) { + omp.taskloop num_tasks(strict, %testi64: i64) { + omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { + // CHECK: omp.yield + omp.yield + } + } + // CHECK: omp.taskloop nogroup { omp.taskloop nogroup { omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {