Skip to content

[MLIR][OpenMP]Add prescriptiveness-modifier support to grainsize and … #128477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -436,12 +436,11 @@ class OpenMP_GrainsizeClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
Optional<IntLikeType>:$grainsize
);
let arguments = (ins OptionalAttr<GrainsizeTypeAttr>:$grainsize_mod,
Optional<IntLikeType>:$grainsize);

let optAssemblyFormat = [{
`grainsize` `(` $grainsize `:` type($grainsize) `)`
`grainsize` `(` custom<GrainsizeClause>($grainsize_mod , $grainsize, type($grainsize)) `)`
}];

let description = [{
Expand Down Expand Up @@ -895,12 +894,11 @@ class OpenMP_NumTasksClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
Optional<IntLikeType>:$num_tasks
);
let arguments = (ins OptionalAttr<NumTasksTypeAttr>:$num_tasks_mod,
Optional<IntLikeType>:$num_tasks);

let optAssemblyFormat = [{
`num_tasks` `(` $num_tasks `:` type($num_tasks) `)`
`num_tasks` `(` custom<NumTasksClause>($num_tasks_mod , $num_tasks, type($num_tasks)) `)`
}];

let description = [{
Expand Down
113 changes: 104 additions & 9 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,99 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
p << stringifyClauseOrderKind(order.getValue());
}

template <typename ClauseTypeAttr, typename ClauseType>
static ParseResult
parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
std::optional<OpAsmParser::UnresolvedOperand> &operand,
Type &operandType,
std::optional<ClauseType> (*symbolizeClause)(StringRef),
StringRef clauseName) {
StringRef enumStr;
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
if (parser.parseComma())
return failure();
} else {
return parser.emitError(parser.getCurrentLocation())
<< "invalid " << clauseName << " modifier : '" << enumStr << "'";
;
}
}

OpAsmParser::UnresolvedOperand var;
if (succeeded(parser.parseOperand(var))) {
operand = var;
} else {
return parser.emitError(parser.getCurrentLocation())
<< "expected " << clauseName << " operand";
}

if (operand.has_value()) {
if (parser.parseColonType(operandType))
return failure();
}

return success();
}

template <typename ClauseTypeAttr, typename ClauseType>
static void
printGranularityClause(OpAsmPrinter &p, Operation *op,
ClauseTypeAttr prescriptiveness, Value operand,
mlir::Type operandType,
StringRef (*stringifyClauseType)(ClauseType)) {

if (prescriptiveness)
p << stringifyClauseType(prescriptiveness.getValue()) << ", ";

if (operand)
p << operand << ": " << operandType;
}

//===----------------------------------------------------------------------===//
// Parser and printer for grainsize Clause
//===----------------------------------------------------------------------===//

// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
static ParseResult
parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
Type &grainsizeType) {
return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
parser, grainsizeMod, grainsize, grainsizeType,
&symbolizeClauseGrainsizeType, "grainsize");
}

static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
ClauseGrainsizeTypeAttr grainsizeMod,
Value grainsize, mlir::Type grainsizeType) {
printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
p, op, grainsizeMod, grainsize, grainsizeType,
&stringifyClauseGrainsizeType);
}

//===----------------------------------------------------------------------===//
// Parser and printer for num_tasks Clause
//===----------------------------------------------------------------------===//

// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
static ParseResult
parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
Type &numTasksType) {
return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
"num_tasks");
}

static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
ClauseNumTasksTypeAttr numTasksMod,
Value numTasks, mlir::Type numTasksType) {
printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}

//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2593,15 +2686,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
const TaskloopOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
TaskloopOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
/*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.final, clauses.grainsizeMod, clauses.grainsize,
clauses.ifExpr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
makeArrayAttr(ctx, clauses.inReductionSyms),
clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
clauses.numTasks, clauses.priority, /*private_vars=*/{},
/*private_syms=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
}

SmallVector<Value> TaskloopOp::getAllReductionVars() {
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,30 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {

// -----

func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testi64 = "test.i64"() : () -> (i64)
// expected-error @below {{invalid grainsize modifier : 'strict1'}}
omp.taskloop grainsize(strict1, %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
omp.yield
}
}
return
}
// -----

func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
%testi64 = "test.i64"() : () -> (i64)
// expected-error @below {{invalid num_tasks modifier : 'default'}}
omp.taskloop num_tasks(default, %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
omp.yield
}
}
return
}
// -----

func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// expected-error @below {{op nested in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
omp.taskloop {
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2417,6 +2417,22 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
}
}

// CHECK: omp.taskloop grainsize(strict, %{{[^:]+}}: i64) {
omp.taskloop grainsize(strict, %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// CHECK: omp.yield
omp.yield
}
}

// CHECK: omp.taskloop num_tasks(strict, %{{[^:]+}}: i64) {
omp.taskloop num_tasks(strict, %testi64: i64) {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
// CHECK: omp.yield
omp.yield
}
}

// CHECK: omp.taskloop nogroup {
omp.taskloop nogroup {
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
Expand Down