
Commit 6f30e4b

Merge pull request #1304 from mfeliz-cruise/michael.feliz/dependency_aware_partitioning
[feat] Add dependency awareness to torch-trt partitioning
2 parents 80b7189 + 56ae9f6 commit 6f30e4b

6 files changed (+308, −23 lines)

core/partitioning/partitioning.cpp

Lines changed: 118 additions & 18 deletions
@@ -111,10 +111,34 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
   }
 }
 
+std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
+  std::set<torch::jit::Node*> dependent_nodes;
+  for (auto val : n->outputs()) {
+    for (auto use : val->uses()) {
+      dependent_nodes.insert(use.user);
+    }
+  }
+  if (const auto* schema = n->maybeSchema()) {
+    for (size_t i = 0; i < n->inputs().size(); ++i) {
+      const at::AliasInfo* formal = schema->arguments()[i].alias_info();
+      if (formal && formal->isWrite()) {
+        for (auto use : n->inputs()[i]->uses()) {
+          torch::jit::Node* use_node = use.user;
+          if (use_node->isAfter(n)) {
+            dependent_nodes.insert(use_node);
+          }
+        }
+      }
+    }
+  }
+  return dependent_nodes;
+}
+
 // Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size
 std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
   auto nodes = block->nodes();
   std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
   std::vector<torch::jit::Node*> min_block_fallback_nodes;
   for (const auto n : nodes) {
     if (n->kind() == torch::jit::prim::Constant) {
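The helper added above gathers every node that must execute after n: direct consumers of n's outputs, plus any later reader of an input that n writes in place (detected through the schema's alias_info()/isWrite()). Below is a minimal standalone sketch of that idea using a hypothetical ToyNode type in place of the real torch::jit IR; the names and fields are illustrative only, not torch-trt or PyTorch APIs.

// Standalone sketch of the dependency-collection idea in getDependentNodes.
// ToyNode is a hypothetical stand-in for torch::jit::Node.
#include <cstddef>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  size_t position;                  // topological position in the block
  std::vector<ToyNode*> consumers;  // users of this node's outputs
  std::vector<ToyNode*> mutated_input_readers;  // readers of an input this node writes in place
};

std::set<ToyNode*> getDependentNodesSketch(ToyNode* n) {
  std::set<ToyNode*> dependent_nodes;
  for (ToyNode* user : n->consumers) {
    dependent_nodes.insert(user);  // mirrors the loop over output uses
  }
  for (ToyNode* reader : n->mutated_input_readers) {
    if (reader->position > n->position) {  // mirrors use_node->isAfter(n)
      dependent_nodes.insert(reader);
    }
  }
  return dependent_nodes;
}

int main() {
  ToyNode div_node{"aten::div", 0, {}, {}};
  ToyNode lgamma_node{"aten::lgamma", 1, {}, {}};
  div_node.consumers.push_back(&lgamma_node);  // %div_lgamma = aten::lgamma(%div)
  for (ToyNode* d : getDependentNodesSketch(&div_node)) {
    std::cout << d->name << " depends on " << div_node.name << "\n";
  }
  return 0;
}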
@@ -124,11 +148,16 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
     // check if current node fallback or not
     if (!ctx->shouldNodeRunInTorch(n)) {
       cur_trt_nodes.push_back(n);
+      auto dependent_nodes = getDependentNodes(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
     } else {
-      if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
-        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      if (cur_trt_nodes_uses.count(n)) {
+        if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
+          min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+        }
+        cur_trt_nodes.clear();
+        cur_trt_nodes_uses.clear();
       }
-      cur_trt_nodes.clear();
     }
   }
   if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
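With the use-set in place, the min_block_size scan above no longer flushes an in-progress TRT run at every Torch node; the run is only measured against min_block_size when the walk reaches a Torch node that actually depends on it, so independent Torch ops cannot break up a TRT run. A hedged standalone simulation of that control flow follows; ToyOp and its fields are illustrative stand-ins for the real lookups (shouldNodeRunInTorch, cur_trt_nodes_uses.count(n)), not torch-trt types.

// Standalone simulation of the revised min_block_size scan. A TRT run is
// evaluated for fallback only at a Torch op that depends on it; independent
// Torch ops leave the run open.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  bool runs_in_trt;                 // stands in for !ctx->shouldNodeRunInTorch(n)
  bool depends_on_current_trt_run;  // stands in for cur_trt_nodes_uses.count(n)
};

int main() {
  const size_t min_block_size = 3;
  std::vector<ToyOp> ops = {
      {"aten::add", true, false},
      {"aten::lgamma", false, false},  // independent Torch op: run stays open
      {"aten::mul", true, false},
      {"aten::div", true, false},
      {"aten::lgamma", false, true},   // dependent Torch op: run is evaluated here
  };
  std::vector<std::string> cur_trt_run, fallback;
  for (const ToyOp& op : ops) {
    if (op.runs_in_trt) {
      cur_trt_run.push_back(op.name);
    } else if (op.depends_on_current_trt_run) {
      if (cur_trt_run.size() < min_block_size) {
        fallback.insert(fallback.end(), cur_trt_run.begin(), cur_trt_run.end());
      }
      cur_trt_run.clear();
    }
  }
  // The run reached 3 ops before any dependent Torch op, so nothing falls back.
  std::cout << "ops forced to Torch: " << fallback.size() << "\n";
  return 0;
}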
@@ -355,6 +384,59 @@ void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
   setMinBlockFallbackNodes(ctx, block);
 }
 
+void merge_adjacent_segments_list_in_new_partition(
+    PartitionedGraph& original_partition,
+    PartitionedGraph& new_partition,
+    SegmentedBlock::SegmentedBlockTarget& segment_kind,
+    std::vector<size_t>& same_type_segment_idx) {
+  TORCHTRT_CHECK(!same_type_segment_idx.empty(), "Unable to merge empty segment list");
+  if (same_type_segment_idx.size() == 1) {
+    new_partition.push_back(original_partition[same_type_segment_idx[0]]);
+  } else {
+    auto first_idx = same_type_segment_idx[0];
+    for (size_t i = 1; i < same_type_segment_idx.size(); ++i) {
+      TORCHTRT_CHECK(
+          same_type_segment_idx[i] == (first_idx + i),
+          "Unable to merge non-sequential segments: " << same_type_segment_idx);
+    }
+    LOG_DEBUG(
+        "Merging adjacent " << SegmentedBlock::target_to_str(segment_kind) << " segments: " << same_type_segment_idx);
+    std::vector<torch::jit::Node*> nodes;
+    for (auto segment_to_merge : same_type_segment_idx) {
+      const auto& merge_nodes = original_partition[segment_to_merge].raw_nodes();
+      nodes.insert(nodes.end(), merge_nodes.begin(), merge_nodes.end());
+    }
+    new_partition.emplace_back(segment_kind, nodes);
+  }
+}
+
+PartitionedGraph merge_adjacent_segments_of_same_type(PartitionedGraph& original_partition) {
+  PartitionedGraph new_partition;
+  SegmentedBlock::SegmentedBlockTarget segment_kind = SegmentedBlock::SegmentedBlockTarget::kTorch;
+  std::vector<size_t> same_type_segment_idx;
+  for (size_t i = 0UL; i < original_partition.size(); ++i) {
+    auto& segment = original_partition[i];
+    if (same_type_segment_idx.empty()) {
+      segment_kind = segment.target();
+    } else if (segment_kind != segment.target() || segment.do_not_merge()) {
+      merge_adjacent_segments_list_in_new_partition(
+          original_partition, new_partition, segment_kind, same_type_segment_idx);
+      same_type_segment_idx.clear();
+      segment_kind = segment.target();
+    }
+    if (segment.do_not_merge()) {
+      new_partition.push_back(segment);
+    } else {
+      same_type_segment_idx.push_back(i);
+    }
+  }
+  if (!same_type_segment_idx.empty()) {
+    merge_adjacent_segments_list_in_new_partition(
+        original_partition, new_partition, segment_kind, same_type_segment_idx);
+  }
+  return new_partition;
+}
+
 void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   // Find all the fallback nodes and build execution decision LUT for all nodes
   setNodeExecutorLUT(ctx, block);
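merge_adjacent_segments_of_same_type rebuilds the partition by folding runs of adjacent segments that share a target and are not marked do_not_merge into single segments. The sketch below compresses the same behavior into a pairwise merge over a toy segment type; the real code collects indices and checks that they are sequential before merging, and ToySegment/Target are illustrative names only.

// Simplified sketch of merging adjacent same-target segments.
#include <iostream>
#include <vector>

enum class Target { kTorch, kTensorRT };

struct ToySegment {
  Target target;
  bool do_not_merge = false;
  int num_nodes = 1;
};

std::vector<ToySegment> mergeAdjacent(const std::vector<ToySegment>& in) {
  std::vector<ToySegment> out;
  for (const ToySegment& s : in) {
    bool can_extend = !out.empty() && !out.back().do_not_merge && !s.do_not_merge &&
        out.back().target == s.target;
    if (can_extend) {
      out.back().num_nodes += s.num_nodes;  // fold nodes into the previous segment
    } else {
      out.push_back(s);
    }
  }
  return out;
}

int main() {
  // TRT, TRT, Torch, Torch (do_not_merge, e.g. a prim::If block), Torch
  std::vector<ToySegment> partition = {
      {Target::kTensorRT}, {Target::kTensorRT}, {Target::kTorch},
      {Target::kTorch, true}, {Target::kTorch}};
  std::cout << mergeAdjacent(partition).size() << " segments\n";  // 4: only the first pair merges
  return 0;
}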
@@ -365,58 +447,75 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   PartitionedGraph segmented_blocks;
 
   std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
+  std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
   for (const auto n : nodes) {
     // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
+    auto dependent_nodes = getDependentNodes(n);
     // the outputs of trt subgraph shouldn't be collections
     if (ctx->shouldNodeRunInTensorRT(n)) {
       in_prog_trt_blk_nodes.push_back(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
 
-      // If there is an active PyTorch block and we have passed the threshold for a valid TRT
-      // block then segment and reset the active PyTorch block
-      if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size && !in_prog_pyt_blk_nodes.empty()) {
+      // If we hit a TRT node that is dependent on nodes in the active PyTorch block, finalize the block to materialize
+      // those dependencies in the graph
+      if (cur_pyt_nodes_uses.count(n)) {
         finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+        cur_pyt_nodes_uses.clear();
       }
     } else {
-      // If there is an active TRT block that is valid segment and reset the active TRT block
-      // otherwise add it to the active PyTorch block and reset
-      if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
-      } else {
-        LOG_DEBUG(
-            "In progress TRT block does not meet minimum block size requirements ("
-            << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
-            << "), therefore folding into in progress PyTorch block");
-        in_prog_pyt_blk_nodes.insert(
-            in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+      // The current node is dependent on the active TRT block, finalize it to materialize those dependencies in the
+      // graph or add them to the active PyTorch block
+      if (cur_trt_nodes_uses.count(n)) {
+        // If there is an active TRT block that is valid segment and reset the active TRT block
+        // otherwise add it to the active PyTorch block and reset
+        if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
+          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+        } else {
+          LOG_DEBUG(
+              "In progress TRT block does not meet minimum block size requirements ("
+              << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
+              << "), therefore folding into in progress PyTorch block");
+          in_prog_pyt_blk_nodes.insert(
+              in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+          cur_pyt_nodes_uses.insert(cur_trt_nodes_uses.begin(), cur_trt_nodes_uses.end());
+        }
+        in_prog_trt_blk_nodes.clear();
+        cur_trt_nodes_uses.clear();
       }
-      in_prog_trt_blk_nodes.clear();
       // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
       // we shouldn't inject node for this block in dependency analysis process
      if (n->kind() == torch::jit::prim::If) {
        LOG_DEBUG(
            "Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional");
        if (!in_prog_pyt_blk_nodes.empty()) {
          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          cur_pyt_nodes_uses.clear();
        }
        auto cond_node = std::vector<torch::jit::Node*>{n};
        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
+        segmented_blocks.back().do_not_merge(true);
        continue;
      } else if (n->kind() == torch::jit::prim::Loop) {
        if (!in_prog_pyt_blk_nodes.empty()) {
          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          cur_pyt_nodes_uses.clear();
        }
        if (checkLoopEvaluatable(n)) {
          in_prog_trt_blk_nodes.push_back(n);
+          cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
        } else {
          auto loop_node = std::vector<torch::jit::Node*>{n};
          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
+          segmented_blocks.back().do_not_merge(true);
        }
        continue;
      }
      in_prog_pyt_blk_nodes.push_back(n);
+      cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
    }
  }
 
@@ -432,6 +531,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
     finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
 
+  segmented_blocks = merge_adjacent_segments_of_same_type(segmented_blocks);
   ctx->partitioned_blocks.insert({block, segmented_blocks});
   return;
 }
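One behavior worth noting in the segmentGraph changes: prim::If and non-evaluatable prim::Loop nodes are still finalized into their own single-node Torch segments, and are now additionally flagged with do_not_merge(true) so the merge pass leaves them intact. A toy sketch of that flow, with hypothetical Kind/Seg types standing in for the torch::jit and torch-trt ones:

// Toy sketch: conditionals become their own unmergeable Torch segments.
#include <iostream>
#include <string>
#include <vector>

enum class Kind { kGeneric, kIf, kLoop };

struct Seg {
  std::string label;
  bool do_not_merge = false;
};

int main() {
  std::vector<Kind> nodes = {Kind::kGeneric, Kind::kIf, Kind::kGeneric};
  std::vector<Seg> blocks;
  std::vector<std::string> in_prog;  // stands in for in_prog_pyt_blk_nodes
  for (Kind k : nodes) {
    if (k == Kind::kIf || k == Kind::kLoop) {
      if (!in_prog.empty()) {  // finalize the open Torch block first
        blocks.push_back({"torch block"});
        in_prog.clear();
      }
      blocks.push_back({"conditional", /*do_not_merge=*/true});
      continue;
    }
    in_prog.push_back("op");
  }
  if (!in_prog.empty()) blocks.push_back({"torch block"});
  for (const Seg& s : blocks) {
    std::cout << s.label << (s.do_not_merge ? " (do_not_merge)" : "") << "\n";
  }
  return 0;
}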

core/partitioning/segmentedblock/SegmentedBlock.h

Lines changed: 9 additions & 0 deletions
@@ -94,6 +94,14 @@ struct SegmentedBlock {
     return target_;
   }
 
+  bool do_not_merge(void) const {
+    return do_not_merge_;
+  }
+
+  void do_not_merge(bool x) {
+    do_not_merge_ = x;
+  }
+
   friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b);
 
  private:
@@ -106,6 +114,7 @@ struct SegmentedBlock {
   std::vector<torch::jit::Node*> nodes_;
   std::shared_ptr<torch::jit::Graph> g_;
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
+  bool do_not_merge_ = false;
 };
 
 std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
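The flag is exposed through an overloaded getter/setter pair sharing one name, matching the accessor style of the surrounding struct. A minimal standalone illustration of the pattern (Block is a hypothetical stand-in for SegmentedBlock):

// Overloaded accessor pattern: same name for getter and setter.
#include <iostream>

struct Block {
  bool do_not_merge() const { return do_not_merge_; }
  void do_not_merge(bool x) { do_not_merge_ = x; }

 private:
  bool do_not_merge_ = false;
};

int main() {
  Block b;
  b.do_not_merge(true);  // mark, e.g., a conditional segment as unmergeable
  std::cout << std::boolalpha << b.do_not_merge() << "\n";  // true
  return 0;
}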

docsrc/contributors/partitioning.rst

Lines changed: 148 additions & 0 deletions
@@ -91,3 +91,151 @@ To enable automatic fallback feature, you can set following attributes in Python
   cfg.torch_executed_ops.push_back("aten::relu");
   auto trt_mod = torchtrt::ts::compile(mod, cfg);
   auto out = trt_mod.forward({in});
+
+Dependency Aware Partitioning
+=============================
+During segmentation, Torch-TensorRT uses a dependency graph of the input TorchScript nodes to reduce the number of segments created. Consider this example from the test Partitioning.SegmentModelWithDependencyAwareness in `tests/core/partitioning/test_segmentation.cpp <https://github.com/pytorch/TensorRT/blob/master/tests/core/partitioning/test_segmentation.cpp>`_.
+
+.. code-block:: none
+
+    graph(%x : Tensor, %y : Tensor):
+        %3 : int = prim::Constant[value=0]()
+        %20 : int = prim::Constant[value=1]()
+        %add : Tensor = aten::add(%x, %y, %20)
+        %x_lgamma : Tensor = aten::lgamma(%x)
+        %mul : Tensor = aten::mul(%x, %y)
+        %y_lgamma : Tensor = aten::lgamma(%y)
+        %div : Tensor = aten::div(%x, %y)
+        %div_lgamma : Tensor = aten::lgamma(%div)
+        %27 : Tensor[] = prim::ListConstruct(%x_lgamma, %y_lgamma, %div_lgamma, %add, %mul)
+        %12 : Tensor = aten::cat(%27, %3)
+        return (%12)
+
+In this graph `aten::lgamma` is not supported by conversion and must be partitioned into a Torch fallback segment. A purely greedy segmentation strategy, which traverses the nodes of the input graph in order and gathers ops with the same target (TensorRT or Torch) into a segment until it encounters an op with a different target, produces a partition with 7 segments, many containing just a single op.
+
+.. code-block:: none
+
+    Segment Block @0:
+        Target: TensorRT
+
+        Graph: graph(%x : Tensor,
+                     %y : Tensor):
+            %3 : int = prim::Constant[value=1]()
+            %0 : Tensor = aten::add(%x, %y, %3)
+            return ()
+
+    Segment Block @1:
+        Target: Torch
+
+        Graph: graph(%x : Tensor):
+            %0 : Tensor = aten::lgamma(%x)
+            return ()
+
+    Segment Block @2:
+        Target: TensorRT
+
+        Graph: graph(%x : Tensor,
+                     %y : Tensor):
+            %0 : Tensor = aten::mul(%x, %y)
+            return ()
+
+    Segment Block @3:
+        Target: Torch
+
+        Graph: graph(%y : Tensor):
+            %0 : Tensor = aten::lgamma(%y)
+            return ()
+
+    Segment Block @4:
+        Target: TensorRT
+
+        Graph: graph(%x : Tensor,
+                     %y : Tensor):
+            %0 : Tensor = aten::div(%x, %y)
+            return ()
+
+    Segment Block @5:
+        Target: Torch
+
+        Graph: graph(%1 : Tensor):
+            %0 : Tensor = aten::lgamma(%1)
+            return ()
+
+    Segment Block @6:
+        Target: TensorRT
+
+        Graph: graph(%1 : Tensor,
+                     %2 : Tensor,
+                     %3 : Tensor,
+                     %4 : Tensor,
+                     %5 : Tensor):
+            %7 : int = prim::Constant[value=0]()
+            %0 : Tensor[] = prim::ListConstruct(%1, %2, %3, %4, %5)
+            %6 : Tensor = aten::cat(%0, %7)
+            return ()
+
+This partition is valid, but the segmentation is suboptimal: the arithmetic ops and the `aten::lgamma` ops are each split into their own segments because the linear traversal of the graph alternates between Torch and TensorRT targets.
+
+.. code-block:: none
+
+    %add : Tensor = aten::add(%x, %y, %20)
+    %x_lgamma : Tensor = aten::lgamma(%x)
+    %mul : Tensor = aten::mul(%x, %y)
+    %y_lgamma : Tensor = aten::lgamma(%y)
+    %div : Tensor = aten::div(%x, %y)
+    %div_lgamma : Tensor = aten::lgamma(%div)
+
+Each of the arithmetic ops in this section depends only on constants and the inputs `%x` and `%y`. The `aten::lgamma` ops depend on the inputs `%x` and `%y` and on the output of `aten::div`. This means we could reorder this portion of the input graph as below without changing its behavior, and the reordered series of ops could be cleanly partitioned into just 2 segments by the greedy approach described above.
+
+.. code-block:: none
+
+    %add : Tensor = aten::add(%x, %y, %20)
+    %mul : Tensor = aten::mul(%x, %y)
+    %div : Tensor = aten::div(%x, %y)
+    %x_lgamma : Tensor = aten::lgamma(%x)
+    %y_lgamma : Tensor = aten::lgamma(%y)
+    %div_lgamma : Tensor = aten::lgamma(%div)
+
+By adding awareness of the dependencies between ops to the basic greedy segmentation approach, we can achieve the same partition without rewriting the graph. We maintain both a Torch-targeted and a TensorRT-targeted segment simultaneously as we traverse the graph, and we only finalize a segment once we reach an op that both depends on an op in that segment and has a different target. This effectively reorders nodes across segment boundaries to create larger segments, while guaranteeing that no node is ever reordered relative to its dependencies, so the behavior of the graph is unchanged.
+
+In this example the arithmetic ops are collected into a TensorRT segment and the `aten::lgamma` ops into a Torch segment. When we encounter `%div_lgamma : Tensor = aten::lgamma(%div)` we see that it depends on `%div : Tensor = aten::div(%x, %y)` in the in-progress TensorRT segment. This triggers finalization of that TensorRT segment, which guarantees it appears before its dependent in the final partition. The Torch segment containing the `aten::lgamma` ops is finalized in turn when we encounter the `prim::ListConstruct` op, which targets TensorRT and depends on the results of the `aten::lgamma` ops.
+
+.. code-block:: none
+
+    Segment Block @0:
+        Target: TensorRT
+
+        Graph: graph(%x : Tensor,
+                     %y : Tensor):
+            %3 : int = prim::Constant[value=1]()
+            %0 : Tensor = aten::add(%x, %y, %3)
+            %4 : Tensor = aten::mul(%x, %y)
+            %5 : Tensor = aten::div(%x, %y)
+            return ()
+
+    Segment Block @1:
+        Target: Torch
+
+        Graph: graph(%x : Tensor,
+                     %y : Tensor,
+                     %5 : Tensor):
+            %0 : Tensor = aten::lgamma(%x)
+            %2 : Tensor = aten::lgamma(%y)
+            %4 : Tensor = aten::lgamma(%5)
+            return ()
+
+    Segment Block @2:
+        Target: TensorRT
+
+        Graph: graph(%1 : Tensor,
+                     %2 : Tensor,
+                     %3 : Tensor,
+                     %4 : Tensor,
+                     %5 : Tensor):
+            %7 : int = prim::Constant[value=0]()
+            %0 : Tensor[] = prim::ListConstruct(%1, %2, %3, %4, %5)
+            %6 : Tensor = aten::cat(%0, %7)
+            return ()
+
+In some cases this approach may create adjacent segments in the partition that have the same target. As a clean-up step, these adjacent segments are consolidated to further reduce the number of segments in the final partition.
+
+The merge step identifies lists of segments that are adjacent in the graph, share a target, and are not marked `do_not_merge`. The nodes from each such list are combined into a single new segment that replaces the merged segments in the partition. The `do_not_merge` marking prevents merging of segments created for conditionals and loops, which are handled as special cases during graph stitching and should not be combined with adjacent segments of the same type.
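To make the walkthrough above concrete, here is a hedged, self-contained simulation of the dependency-aware traversal run on this example. Op, its fields, and the name-based dependency test are illustrative stand-ins: the real pass operates on torch::jit::Node* and tracks the dependents of each in-progress segment via getDependentNodes, which amounts to the same membership test used here; min_block_size handling is omitted for brevity.

// Self-contained simulation of dependency-aware segmentation on the
// lgamma example graph from the docs above.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Op {
  std::string name;
  bool trt;                       // target: TensorRT (true) or Torch (false)
  std::vector<std::string> deps;  // ops whose outputs this op reads
};

int main() {
  std::vector<Op> graph = {
      {"add", true, {}},  {"x_lgamma", false, {}},
      {"mul", true, {}},  {"y_lgamma", false, {}},
      {"div", true, {}},  {"div_lgamma", false, {"div"}},
      {"cat", true, {"add", "mul", "x_lgamma", "y_lgamma", "div_lgamma"}}};

  std::vector<std::string> trt_blk, pyt_blk;
  std::set<std::string> trt_names, pyt_names;
  int idx = 0;
  auto depends_on = [](const Op& op, const std::set<std::string>& blk) {
    for (const auto& d : op.deps) {
      if (blk.count(d)) return true;
    }
    return false;
  };
  auto flush = [&idx](std::vector<std::string>& blk, std::set<std::string>& names,
                      const char* kind) {
    if (blk.empty()) return;
    std::cout << "Segment Block @" << idx++ << " (" << kind << "): " << blk.size()
              << " ops\n";
    blk.clear();
    names.clear();
  };
  for (const Op& op : graph) {
    if (op.trt) {
      // A TRT op that depends on the open Torch block finalizes that block.
      if (depends_on(op, pyt_names)) flush(pyt_blk, pyt_names, "Torch");
      trt_blk.push_back(op.name);
      trt_names.insert(op.name);
    } else {
      // A Torch op that depends on the open TRT block finalizes that block.
      if (depends_on(op, trt_names)) flush(trt_blk, trt_names, "TensorRT");
      pyt_blk.push_back(op.name);
      pyt_names.insert(op.name);
    }
  }
  flush(trt_blk, trt_names, "TensorRT");
  flush(pyt_blk, pyt_names, "Torch");
  return 0;
}

Run on this input, the simulation should reproduce the documented three-segment partition: a TensorRT segment with 3 ops, a Torch segment with 3 ops, and a final TensorRT segment with 1 op.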

tests/core/partitioning/test_conditionals.cpp

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
 
   auto conditional_engines_count = count_trt_engines_in_conditionals(new_g);
 
-  ASSERT_TRUE(conditional_engines_count == 2);
+  ASSERT_TRUE(conditional_engines_count == 1);
 }
 
 TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
