Skip to content

Commit 4cb1d01

Browse files
committed
[MLIR][OpenMP]Refactored parser and printer function of grainsize and numtasks clause.
1 parent 9bb3994 commit 4cb1d01

File tree

2 files changed

+54
-63
lines changed

2 files changed

+54
-63
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 53 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -472,55 +472,76 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472472
p << stringifyClauseOrderKind(order.getValue());
473473
}
474474

475-
//===----------------------------------------------------------------------===//
476-
// Parser and printer for grainsize Clause
477-
//===----------------------------------------------------------------------===//
478-
479-
// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
475+
template <typename ClauseTypeAttr, typename ClauseType>
480476
static ParseResult
481-
parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
482-
std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
483-
Type &grainsizeType) {
484-
SMLoc loc = parser.getCurrentLocation();
477+
parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478+
std::optional<OpAsmParser::UnresolvedOperand> &operand,
479+
Type &operandType,
480+
std::optional<ClauseType> (*symbolizeClause)(StringRef),
481+
StringRef clauseName) {
485482
StringRef enumStr;
486-
487483
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
488-
if (std::optional<ClauseGrainsizeType> enumValue =
489-
symbolizeClauseGrainsizeType(enumStr)) {
490-
grainsizeMod =
491-
ClauseGrainsizeTypeAttr::get(parser.getContext(), *enumValue);
484+
if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
485+
prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
492486
if (parser.parseColon())
493487
return failure();
494488
} else {
495-
return parser.emitError(loc, "invalid grainsize modifier : '")
496-
<< enumStr << "'";
489+
return parser.emitError(parser.getCurrentLocation())
490+
<< "invalid " << clauseName << " modifier : '" << enumStr << "'";
491+
;
497492
}
498493
}
499494

500-
OpAsmParser::UnresolvedOperand operand;
501-
if (succeeded(parser.parseOperand(operand))) {
502-
grainsize = operand;
495+
OpAsmParser::UnresolvedOperand var;
496+
if (succeeded(parser.parseOperand(var))) {
497+
operand = var;
503498
} else {
504499
return parser.emitError(parser.getCurrentLocation())
505-
<< "expected grainsize operand";
500+
<< "expected " << clauseName << " operand";
506501
}
507502

508-
if (grainsize.has_value()) {
509-
if (parser.parseColonType(grainsizeType))
503+
if (operand.has_value()) {
504+
if (parser.parseColonType(operandType))
510505
return failure();
511506
}
512507

513508
return success();
514509
}
515510

511+
template <typename ClauseTypeAttr, typename ClauseType>
512+
static void
513+
printGranularityClause(OpAsmPrinter &p, Operation *op,
514+
ClauseTypeAttr prescriptiveness, Value operand,
515+
mlir::Type operandType,
516+
StringRef (*stringifyClauseType)(ClauseType)) {
517+
518+
if (prescriptiveness)
519+
p << stringifyClauseType(prescriptiveness.getValue()) << ": ";
520+
521+
if (operand)
522+
p << operand << ": " << operandType;
523+
}
524+
525+
//===----------------------------------------------------------------------===//
526+
// Parser and printer for grainsize Clause
527+
//===----------------------------------------------------------------------===//
528+
529+
// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530+
static ParseResult
531+
parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532+
std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533+
Type &grainsizeType) {
534+
return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535+
parser, grainsizeMod, grainsize, grainsizeType,
536+
&symbolizeClauseGrainsizeType, "grainsize");
537+
}
538+
516539
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
517540
ClauseGrainsizeTypeAttr grainsizeMod,
518541
Value grainsize, mlir::Type grainsizeType) {
519-
if (grainsizeMod)
520-
p << stringifyClauseGrainsizeType(grainsizeMod.getValue()) << ": ";
521-
522-
if (grainsize)
523-
p << grainsize << ": " << grainsizeType;
542+
printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543+
p, op, grainsizeMod, grainsize, grainsizeType,
544+
&stringifyClauseGrainsizeType);
524545
}
525546

526547
//===----------------------------------------------------------------------===//
@@ -532,46 +553,16 @@ static ParseResult
532553
parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
533554
std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
534555
Type &numTasksType) {
535-
SMLoc loc = parser.getCurrentLocation();
536-
StringRef enumStr;
537-
538-
if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
539-
if (std::optional<ClauseNumTasksType> enumValue =
540-
symbolizeClauseNumTasksType(enumStr)) {
541-
numTasksMod =
542-
ClauseNumTasksTypeAttr::get(parser.getContext(), *enumValue);
543-
if (parser.parseColon())
544-
return failure();
545-
} else {
546-
return parser.emitError(loc, "invalid numTasks modifier : '")
547-
<< enumStr << "'";
548-
}
549-
}
550-
551-
OpAsmParser::UnresolvedOperand operand;
552-
if (succeeded(parser.parseOperand(operand))) {
553-
numTasks = operand;
554-
} else {
555-
return parser.emitError(parser.getCurrentLocation())
556-
<< "expected num_tasks operand";
557-
}
558-
559-
if (numTasks.has_value()) {
560-
if (parser.parseColonType(numTasksType))
561-
return failure();
562-
}
563-
564-
return success();
556+
return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557+
parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558+
"num_tasks");
565559
}
566560

567561
static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
568562
ClauseNumTasksTypeAttr numTasksMod,
569563
Value numTasks, mlir::Type numTasksType) {
570-
if (numTasksMod)
571-
p << stringifyClauseNumTasksType(numTasksMod.getValue()) << ": ";
572-
573-
if (numTasks)
574-
p << numTasks << ": " << numTasksType;
564+
printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565+
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
575566
}
576567

577568
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2078,7 +2078,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
20782078

20792079
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
20802080
%testi64 = "test.i64"() : () -> (i64)
2081-
// expected-error @below {{invalid numTasks modifier : 'default'}}
2081+
// expected-error @below {{invalid num_tasks modifier : 'default'}}
20822082
omp.taskloop num_tasks(default: %testi64: i64) {
20832083
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
20842084
omp.yield

0 commit comments

Comments
 (0)