@@ -472,55 +472,76 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472
472
p << stringifyClauseOrderKind (order.getValue ());
473
473
}
474
474
475
- // ===----------------------------------------------------------------------===//
476
- // Parser and printer for grainsize Clause
477
- // ===----------------------------------------------------------------------===//
478
-
479
- // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
475
+ template <typename ClauseTypeAttr, typename ClauseType>
480
476
static ParseResult
481
- parseGrainsizeClause (OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
482
- std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
483
- Type &grainsizeType) {
484
- SMLoc loc = parser.getCurrentLocation ();
477
+ parseGranularityClause (OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478
+ std::optional<OpAsmParser::UnresolvedOperand> &operand,
479
+ Type &operandType,
480
+ std::optional<ClauseType> (*symbolizeClause)(StringRef),
481
+ StringRef clauseName) {
485
482
StringRef enumStr;
486
-
487
483
if (succeeded (parser.parseOptionalKeyword (&enumStr))) {
488
- if (std::optional<ClauseGrainsizeType> enumValue =
489
- symbolizeClauseGrainsizeType (enumStr)) {
490
- grainsizeMod =
491
- ClauseGrainsizeTypeAttr::get (parser.getContext (), *enumValue);
484
+ if (std::optional<ClauseType> enumValue = symbolizeClause (enumStr)) {
485
+ prescriptiveness = ClauseTypeAttr::get (parser.getContext (), *enumValue);
492
486
if (parser.parseColon ())
493
487
return failure ();
494
488
} else {
495
- return parser.emitError (loc, " invalid grainsize modifier : '" )
496
- << enumStr << " '" ;
489
+ return parser.emitError (parser.getCurrentLocation ())
490
+ << " invalid " << clauseName << " modifier : '" << enumStr << " '" ;
491
+ ;
497
492
}
498
493
}
499
494
500
- OpAsmParser::UnresolvedOperand operand ;
501
- if (succeeded (parser.parseOperand (operand ))) {
502
- grainsize = operand ;
495
+ OpAsmParser::UnresolvedOperand var ;
496
+ if (succeeded (parser.parseOperand (var ))) {
497
+ operand = var ;
503
498
} else {
504
499
return parser.emitError (parser.getCurrentLocation ())
505
- << " expected grainsize operand" ;
500
+ << " expected " << clauseName << " operand" ;
506
501
}
507
502
508
- if (grainsize .has_value ()) {
509
- if (parser.parseColonType (grainsizeType ))
503
+ if (operand .has_value ()) {
504
+ if (parser.parseColonType (operandType ))
510
505
return failure ();
511
506
}
512
507
513
508
return success ();
514
509
}
515
510
511
+ template <typename ClauseTypeAttr, typename ClauseType>
512
+ static void
513
+ printGranularityClause (OpAsmPrinter &p, Operation *op,
514
+ ClauseTypeAttr prescriptiveness, Value operand,
515
+ mlir::Type operandType,
516
+ StringRef (*stringifyClauseType)(ClauseType)) {
517
+
518
+ if (prescriptiveness)
519
+ p << stringifyClauseType (prescriptiveness.getValue ()) << " : " ;
520
+
521
+ if (operand)
522
+ p << operand << " : " << operandType;
523
+ }
524
+
525
+ // ===----------------------------------------------------------------------===//
526
+ // Parser and printer for grainsize Clause
527
+ // ===----------------------------------------------------------------------===//
528
+
529
+ // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530
+ static ParseResult
531
+ parseGrainsizeClause (OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532
+ std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533
+ Type &grainsizeType) {
534
+ return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535
+ parser, grainsizeMod, grainsize, grainsizeType,
536
+ &symbolizeClauseGrainsizeType, " grainsize" );
537
+ }
538
+
516
539
static void printGrainsizeClause (OpAsmPrinter &p, Operation *op,
517
540
ClauseGrainsizeTypeAttr grainsizeMod,
518
541
Value grainsize, mlir::Type grainsizeType) {
519
- if (grainsizeMod)
520
- p << stringifyClauseGrainsizeType (grainsizeMod.getValue ()) << " : " ;
521
-
522
- if (grainsize)
523
- p << grainsize << " : " << grainsizeType;
542
+ printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543
+ p, op, grainsizeMod, grainsize, grainsizeType,
544
+ &stringifyClauseGrainsizeType);
524
545
}
525
546
526
547
// ===----------------------------------------------------------------------===//
@@ -532,46 +553,16 @@ static ParseResult
532
553
parseNumTasksClause (OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
533
554
std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
534
555
Type &numTasksType) {
535
- SMLoc loc = parser.getCurrentLocation ();
536
- StringRef enumStr;
537
-
538
- if (succeeded (parser.parseOptionalKeyword (&enumStr))) {
539
- if (std::optional<ClauseNumTasksType> enumValue =
540
- symbolizeClauseNumTasksType (enumStr)) {
541
- numTasksMod =
542
- ClauseNumTasksTypeAttr::get (parser.getContext (), *enumValue);
543
- if (parser.parseColon ())
544
- return failure ();
545
- } else {
546
- return parser.emitError (loc, " invalid numTasks modifier : '" )
547
- << enumStr << " '" ;
548
- }
549
- }
550
-
551
- OpAsmParser::UnresolvedOperand operand;
552
- if (succeeded (parser.parseOperand (operand))) {
553
- numTasks = operand;
554
- } else {
555
- return parser.emitError (parser.getCurrentLocation ())
556
- << " expected num_tasks operand" ;
557
- }
558
-
559
- if (numTasks.has_value ()) {
560
- if (parser.parseColonType (numTasksType))
561
- return failure ();
562
- }
563
-
564
- return success ();
556
+ return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557
+ parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558
+ " num_tasks" );
565
559
}
566
560
567
561
static void printNumTasksClause (OpAsmPrinter &p, Operation *op,
568
562
ClauseNumTasksTypeAttr numTasksMod,
569
563
Value numTasks, mlir::Type numTasksType) {
570
- if (numTasksMod)
571
- p << stringifyClauseNumTasksType (numTasksMod.getValue ()) << " : " ;
572
-
573
- if (numTasks)
574
- p << numTasks << " : " << numTasksType;
564
+ printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565
+ p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
575
566
}
576
567
577
568
// ===----------------------------------------------------------------------===//
0 commit comments