@@ -111,10 +111,34 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
  }
}

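+// Returns the set of nodes that depend on node n: the direct users of n's outputs,
+// plus, when n's schema marks an input as written to (e.g. an in-place op such as
+// aten::add_), any later users of that mutated input, since they observe n's side
+// effect.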
+std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
+  std::set<torch::jit::Node*> dependent_nodes;
+  for (auto val : n->outputs()) {
+    for (auto use : val->uses()) {
+      dependent_nodes.insert(use.user);
+    }
+  }
+  if (const auto* schema = n->maybeSchema()) {
+    for (size_t i = 0; i < n->inputs().size(); ++i) {
+      const at::AliasInfo* formal = schema->arguments()[i].alias_info();
+      if (formal && formal->isWrite()) {
+        for (auto use : n->inputs()[i]->uses()) {
+          torch::jit::Node* use_node = use.user;
+          if (use_node->isAfter(n)) {
+            dependent_nodes.insert(use_node);
+          }
+        }
+      }
+    }
+  }
+  return dependent_nodes;
+}
+
// Sub-function that traverses the entire block and checks whether the TensorRT node sequences satisfy min_block_size
std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
  auto nodes = block->nodes();
  std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
  std::vector<torch::jit::Node*> min_block_fallback_nodes;
  for (const auto n : nodes) {
    if (n->kind() == torch::jit::prim::Constant) {
@@ -124,11 +148,16 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
    // check whether the current node falls back or not
    if (!ctx->shouldNodeRunInTorch(n)) {
      cur_trt_nodes.push_back(n);
+      auto dependent_nodes = getDependentNodes(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
    } else {
-      if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
-        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
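+      // Only treat this Torch node as breaking the TRT sequence when it actually
+      // depends on the in-progress TRT nodes; an independent fallback node leaves
+      // cur_trt_nodes intact, so the TRT sequence keeps growing across it.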
+      if (cur_trt_nodes_uses.count(n)) {
+        if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
+          min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+        }
+        cur_trt_nodes.clear();
+        cur_trt_nodes_uses.clear();
      }
-      cur_trt_nodes.clear();
    }
  }
  if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
@@ -355,6 +384,59 @@ void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
  setMinBlockFallbackNodes(ctx, block);
}

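+// Appends the segments listed in same_type_segment_idx to new_partition: a single
+// segment is copied as-is, while a run of segments (which must be consecutive
+// indices) is fused into one segment of the given kind containing all their nodes.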
+void merge_adjacent_segments_list_in_new_partition(
+    PartitionedGraph& original_partition,
+    PartitionedGraph& new_partition,
+    SegmentedBlock::SegmentedBlockTarget& segment_kind,
+    std::vector<size_t>& same_type_segment_idx) {
+  TORCHTRT_CHECK(!same_type_segment_idx.empty(), "Unable to merge empty segment list");
+  if (same_type_segment_idx.size() == 1) {
+    new_partition.push_back(original_partition[same_type_segment_idx[0]]);
+  } else {
+    auto first_idx = same_type_segment_idx[0];
+    for (size_t i = 1; i < same_type_segment_idx.size(); ++i) {
+      TORCHTRT_CHECK(
+          same_type_segment_idx[i] == (first_idx + i),
+          "Unable to merge non-sequential segments: " << same_type_segment_idx);
+    }
+    LOG_DEBUG(
+        "Merging adjacent " << SegmentedBlock::target_to_str(segment_kind) << " segments: " << same_type_segment_idx);
+    std::vector<torch::jit::Node*> nodes;
+    for (auto segment_to_merge : same_type_segment_idx) {
+      const auto& merge_nodes = original_partition[segment_to_merge].raw_nodes();
+      nodes.insert(nodes.end(), merge_nodes.begin(), merge_nodes.end());
+    }
+    new_partition.emplace_back(segment_kind, nodes);
+  }
+}
+
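+// Rebuilds the partition by fusing maximal runs of adjacent segments that share the
+// same target (Torch or TensorRT). Segments flagged do_not_merge() (conditionals and
+// non-evaluatable loops) are emitted standalone and also act as barriers that end
+// the current run.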
+PartitionedGraph merge_adjacent_segments_of_same_type(PartitionedGraph& original_partition) {
+  PartitionedGraph new_partition;
+  SegmentedBlock::SegmentedBlockTarget segment_kind = SegmentedBlock::SegmentedBlockTarget::kTorch;
+  std::vector<size_t> same_type_segment_idx;
+  for (size_t i = 0UL; i < original_partition.size(); ++i) {
+    auto& segment = original_partition[i];
+    if (same_type_segment_idx.empty()) {
+      segment_kind = segment.target();
+    } else if (segment_kind != segment.target() || segment.do_not_merge()) {
+      merge_adjacent_segments_list_in_new_partition(
+          original_partition, new_partition, segment_kind, same_type_segment_idx);
+      same_type_segment_idx.clear();
+      segment_kind = segment.target();
+    }
+    if (segment.do_not_merge()) {
+      new_partition.push_back(segment);
+    } else {
+      same_type_segment_idx.push_back(i);
+    }
+  }
+  if (!same_type_segment_idx.empty()) {
+    merge_adjacent_segments_list_in_new_partition(
+        original_partition, new_partition, segment_kind, same_type_segment_idx);
+  }
+  return new_partition;
+}
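+// For illustration, a hypothetical segment sequence
+//   [TRT, TRT, Torch, Torch(do_not_merge), TRT]
+// becomes
+//   [TRT(first two fused), Torch, Torch(do_not_merge), TRT]
+// -- the leading TRT pair is fused, while the do_not_merge flag keeps the two
+// adjacent Torch segments from being folded together.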
+
void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
  // Find all the fallback nodes and build the execution decision LUT for all nodes
  setNodeExecutorLUT(ctx, block);
@@ -365,58 +447,75 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
  PartitionedGraph segmented_blocks;

  std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
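+  // Dependents of the nodes accumulated in the in-progress TRT / Torch blocks;
+  // used to decide when a new node forces the opposite block to be finalized.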
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
+  std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
  for (const auto n : nodes) {
    // Skip constant nodes as they are resources for both kinds of modules
    if (n->kind() == torch::jit::prim::Constant) {
      continue;
    }
+    auto dependent_nodes = getDependentNodes(n);
    // the outputs of a TRT subgraph shouldn't be collections
    if (ctx->shouldNodeRunInTensorRT(n)) {
      in_prog_trt_blk_nodes.push_back(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());

-      // If there is an active PyTorch block and we have passed the threshold for a valid TRT
-      // block then segment and reset the active PyTorch block
-      if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size && !in_prog_pyt_blk_nodes.empty()) {
+      // If we hit a TRT node that depends on nodes in the active PyTorch block, finalize that block to
+      // materialize those dependencies in the graph
+      if (cur_pyt_nodes_uses.count(n)) {
        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+        cur_pyt_nodes_uses.clear();
      }
    } else {
-      // If there is an active TRT block that is valid segment and reset the active TRT block
-      // otherwise add it to the active PyTorch block and reset
-      if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
-      } else {
-        LOG_DEBUG(
-            "In progress TRT block does not meet minimum block size requirements ("
-            << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
-            << "), therefore folding into in progress PyTorch block");
-        in_prog_pyt_blk_nodes.insert(
-            in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+      // The current node depends on the active TRT block; finalize it to materialize those dependencies
+      // in the graph, or fold its nodes into the active PyTorch block
+      if (cur_trt_nodes_uses.count(n)) {
+        // If the active TRT block is a valid segment, finalize and reset it;
+        // otherwise fold it into the active PyTorch block and reset
+        if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
+          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+        } else {
+          LOG_DEBUG(
+              "In progress TRT block does not meet minimum block size requirements ("
+              << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
+              << "), therefore folding into in progress PyTorch block");
+          in_prog_pyt_blk_nodes.insert(
+              in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+          cur_pyt_nodes_uses.insert(cur_trt_nodes_uses.begin(), cur_trt_nodes_uses.end());
+        }
+        in_prog_trt_blk_nodes.clear();
+        cur_trt_nodes_uses.clear();
      }
-      in_prog_trt_blk_nodes.clear();
    // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
    // we shouldn't inject nodes for this block in the dependency analysis process
    if (n->kind() == torch::jit::prim::If) {
      LOG_DEBUG(
          "Hit a conditional statement, finalizing in progress PYT block and creating a new one for the conditional");
      if (!in_prog_pyt_blk_nodes.empty()) {
        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+        cur_pyt_nodes_uses.clear();
      }
      auto cond_node = std::vector<torch::jit::Node*>{n};
      finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
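+      // Flag the conditional's block so the later same-type merge pass leaves it standalone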
+      segmented_blocks.back().do_not_merge(true);
      continue;
    } else if (n->kind() == torch::jit::prim::Loop) {
      if (!in_prog_pyt_blk_nodes.empty()) {
        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+        cur_pyt_nodes_uses.clear();
      }
      if (checkLoopEvaluatable(n)) {
        in_prog_trt_blk_nodes.push_back(n);
+        cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
      } else {
        auto loop_node = std::vector<torch::jit::Node*>{n};
        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
+        segmented_blocks.back().do_not_merge(true);
      }
      continue;
    }
    in_prog_pyt_blk_nodes.push_back(n);
+    cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
  }

@@ -432,6 +531,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
    finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
  }

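+  // Collapse adjacent segments with the same target before recording the partition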
+  segmented_blocks = merge_adjacent_segments_of_same_type(segmented_blocks);
  ctx->partitioned_blocks.insert({block, segmented_blocks});
  return;
}