@@ -472,6 +472,108 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472
472
p << stringifyClauseOrderKind (order.getValue ());
473
473
}
474
474
475
+ // ===----------------------------------------------------------------------===//
476
+ // Parser and printer for grainsize Clause
477
+ // ===----------------------------------------------------------------------===//
478
+
479
+ // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
480
+ static ParseResult
481
+ parseGrainsizeClause (OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
482
+ std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
483
+ Type &grainsizeType) {
484
+ SMLoc loc = parser.getCurrentLocation ();
485
+ StringRef enumStr;
486
+
487
+ if (succeeded (parser.parseOptionalKeyword (&enumStr))) {
488
+ if (std::optional<ClauseGrainsizeType> enumValue =
489
+ symbolizeClauseGrainsizeType (enumStr)) {
490
+ grainsizeMod =
491
+ ClauseGrainsizeTypeAttr::get (parser.getContext (), *enumValue);
492
+ if (parser.parseColon ())
493
+ return failure ();
494
+ } else {
495
+ return parser.emitError (loc, " invalid grainsize modifier : '" )
496
+ << enumStr << " '" ;
497
+ }
498
+ }
499
+
500
+ OpAsmParser::UnresolvedOperand operand;
501
+ if (succeeded (parser.parseOperand (operand))) {
502
+ grainsize = operand;
503
+ } else {
504
+ return parser.emitError (parser.getCurrentLocation ())
505
+ << " expected grainsize operand" ;
506
+ }
507
+
508
+ if (grainsize.has_value ()) {
509
+ if (parser.parseColonType (grainsizeType))
510
+ return failure ();
511
+ }
512
+
513
+ return success ();
514
+ }
515
+
516
+ static void printGrainsizeClause (OpAsmPrinter &p, Operation *op,
517
+ ClauseGrainsizeTypeAttr grainsizeMod,
518
+ Value grainsize, mlir::Type grainsizeType) {
519
+ if (grainsizeMod)
520
+ p << stringifyClauseGrainsizeType (grainsizeMod.getValue ()) << " : " ;
521
+
522
+ if (grainsize)
523
+ p << grainsize << " : " << grainsizeType;
524
+ }
525
+
526
+ // ===----------------------------------------------------------------------===//
527
+ // Parser and printer for num_tasks Clause
528
+ // ===----------------------------------------------------------------------===//
529
+
530
+ // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
531
+ static ParseResult
532
+ parseNumTasksClause (OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
533
+ std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
534
+ Type &numTasksType) {
535
+ SMLoc loc = parser.getCurrentLocation ();
536
+ StringRef enumStr;
537
+
538
+ if (succeeded (parser.parseOptionalKeyword (&enumStr))) {
539
+ if (std::optional<ClauseNumTasksType> enumValue =
540
+ symbolizeClauseNumTasksType (enumStr)) {
541
+ numTasksMod =
542
+ ClauseNumTasksTypeAttr::get (parser.getContext (), *enumValue);
543
+ if (parser.parseColon ())
544
+ return failure ();
545
+ } else {
546
+ return parser.emitError (loc, " invalid numTasks modifier : '" )
547
+ << enumStr << " '" ;
548
+ }
549
+ }
550
+
551
+ OpAsmParser::UnresolvedOperand operand;
552
+ if (succeeded (parser.parseOperand (operand))) {
553
+ numTasks = operand;
554
+ } else {
555
+ return parser.emitError (parser.getCurrentLocation ())
556
+ << " expected num_tasks operand" ;
557
+ }
558
+
559
+ if (numTasks.has_value ()) {
560
+ if (parser.parseColonType (numTasksType))
561
+ return failure ();
562
+ }
563
+
564
+ return success ();
565
+ }
566
+
567
+ static void printNumTasksClause (OpAsmPrinter &p, Operation *op,
568
+ ClauseNumTasksTypeAttr numTasksMod,
569
+ Value numTasks, mlir::Type numTasksType) {
570
+ if (numTasksMod)
571
+ p << stringifyClauseNumTasksType (numTasksMod.getValue ()) << " : " ;
572
+
573
+ if (numTasks)
574
+ p << numTasks << " : " << numTasksType;
575
+ }
576
+
475
577
// ===----------------------------------------------------------------------===//
476
578
// Parsers for operations including clauses that define entry block arguments.
477
579
// ===----------------------------------------------------------------------===//
@@ -2593,15 +2695,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2593
2695
const TaskloopOperands &clauses) {
2594
2696
MLIRContext *ctx = builder.getContext ();
2595
2697
// TODO Store clauses in op: privateVars, privateSyms.
2596
- TaskloopOp::build (
2597
- builder, state, clauses.allocateVars , clauses.allocatorVars ,
2598
- clauses.final , clauses.grainsize , clauses.ifExpr , clauses.inReductionVars ,
2599
- makeDenseBoolArrayAttr (ctx, clauses.inReductionByref ),
2600
- makeArrayAttr (ctx, clauses.inReductionSyms ), clauses.mergeable ,
2601
- clauses.nogroup , clauses.numTasks , clauses.priority , /* private_vars=*/ {},
2602
- /* private_syms=*/ nullptr , clauses.reductionMod , clauses.reductionVars ,
2603
- makeDenseBoolArrayAttr (ctx, clauses.reductionByref ),
2604
- makeArrayAttr (ctx, clauses.reductionSyms ), clauses.untied );
2698
+ TaskloopOp::build (builder, state, clauses.allocateVars , clauses.allocatorVars ,
2699
+ clauses.final , clauses.grainsizeMod , clauses.grainsize ,
2700
+ clauses.ifExpr , clauses.inReductionVars ,
2701
+ makeDenseBoolArrayAttr (ctx, clauses.inReductionByref ),
2702
+ makeArrayAttr (ctx, clauses.inReductionSyms ),
2703
+ clauses.mergeable , clauses.nogroup , clauses.numTasksMod ,
2704
+ clauses.numTasks , clauses.priority , /* private_vars=*/ {},
2705
+ /* private_syms=*/ nullptr , clauses.reductionMod ,
2706
+ clauses.reductionVars ,
2707
+ makeDenseBoolArrayAttr (ctx, clauses.reductionByref ),
2708
+ makeArrayAttr (ctx, clauses.reductionSyms ), clauses.untied );
2605
2709
}
2606
2710
2607
2711
SmallVector<Value> TaskloopOp::getAllReductionVars () {
0 commit comments