Commit 3a33b6e

[feat] Add dependency awareness to torch-trt partitioning (#40)
Adds a heuristic to Torch-TRT partitioning's segmentation to avoid materializing a segment until the traversal hits a dependency of that segment. This can significantly reduce the number of segments/engines in cases where the linear traversal of TorchScript nodes would otherwise produce alternating Torch and TRT segments that do not depend on each other.

Fixes # (issue)

Type of change:

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update

Checklist:

- [ ] My code follows the style guidelines of this project (you can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that the relevant reviewers are notified
1 parent e608bc7 commit 3a33b6e
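To illustrate the heuristic described above, here is a minimal standalone sketch, not Torch-TensorRT code: the `Node` struct and the driver are hypothetical stand-ins for JIT nodes, and it ignores `min_block_size` and the merge pass. It only contrasts naive "cut on every executor flip" segmentation with dependency-aware deferral.

```cpp
// Illustrative sketch only (hypothetical types, not the Torch-TensorRT API).
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string name;
  bool runs_in_trt;             // executor decision for this node
  std::vector<size_t> inputs;   // indices of producer nodes
};

int main() {
  // Linear order alternates TRT/Torch, but neither Torch node consumes a TRT output.
  std::vector<Node> nodes = {
      {"trt0", true, {}}, {"pyt0", false, {}}, {"trt1", true, {0}}, {"pyt1", false, {1}}};

  // Naive segmentation: cut every time the executor flips -> 4 segments.
  size_t naive_segments = 1;
  for (size_t i = 1; i < nodes.size(); ++i) {
    if (nodes[i].runs_in_trt != nodes[i - 1].runs_in_trt) {
      ++naive_segments;
    }
  }

  // Dependency-aware segmentation: keep both in-progress segments open and only
  // finalize the opposite segment when the current node actually depends on it.
  std::unordered_set<size_t> open_trt, open_pyt;  // node indices in each open segment
  size_t segments = 0;
  auto depends_on = [](const Node& n, const std::unordered_set<size_t>& seg) {
    for (auto in : n.inputs) {
      if (seg.count(in)) return true;
    }
    return false;
  };
  for (size_t i = 0; i < nodes.size(); ++i) {
    const Node& n = nodes[i];
    if (n.runs_in_trt) {
      if (depends_on(n, open_pyt)) { ++segments; open_pyt.clear(); }  // materialize Torch segment
      open_trt.insert(i);
    } else {
      if (depends_on(n, open_trt)) { ++segments; open_trt.clear(); }  // materialize TRT segment
      open_pyt.insert(i);
    }
  }
  segments += (!open_trt.empty()) + (!open_pyt.empty());

  std::cout << "naive: " << naive_segments << " segments, dependency-aware: " << segments
            << " segments\n";  // prints: naive: 4 segments, dependency-aware: 2 segments
  return 0;
}
```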

File tree

5 files changed: +466 -329 lines changed

core/partitioning/partitioning.cpp

Lines changed: 118 additions & 18 deletions
```diff
@@ -111,10 +111,34 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
   }
 }
 
+std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
+  std::set<torch::jit::Node*> dependent_nodes;
+  for (auto val : n->outputs()) {
+    for (auto use : val->uses()) {
+      dependent_nodes.insert(use.user);
+    }
+  }
+  if (const auto* schema = n->maybeSchema()) {
+    for (size_t i = 0; i < n->inputs().size(); ++i) {
+      const at::AliasInfo* formal = schema->arguments()[i].alias_info();
+      if (formal && formal->isWrite()) {
+        for (auto use : n->inputs()[i]->uses()) {
+          torch::jit::Node* use_node = use.user;
+          if (use_node->isAfter(n)) {
+            dependent_nodes.insert(use_node);
+          }
+        }
+      }
+    }
+  }
+  return dependent_nodes;
+}
+
 // Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size
 std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
   auto nodes = block->nodes();
   std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
   std::vector<torch::jit::Node*> min_block_fallback_nodes;
   for (const auto n : nodes) {
     if (n->kind() == torch::jit::prim::Constant) {
```
```diff
@@ -124,11 +148,16 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
     // check if current node fallback or not
     if (!ctx->shouldNodeRunInTorch(n)) {
       cur_trt_nodes.push_back(n);
+      auto dependent_nodes = getDependentNodes(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
     } else {
-      if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
-        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      if (cur_trt_nodes_uses.count(n)) {
+        if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
+          min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+        }
+        cur_trt_nodes.clear();
+        cur_trt_nodes_uses.clear();
       }
-      cur_trt_nodes.clear();
     }
   }
   if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
```
```diff
@@ -355,6 +384,59 @@ void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
   setMinBlockFallbackNodes(ctx, block);
 }
 
+void merge_adjacent_segments_list_in_new_partition(
+    PartitionedGraph& original_partition,
+    PartitionedGraph& new_partition,
+    SegmentedBlock::SegmentedBlockTarget& segment_kind,
+    std::vector<size_t>& same_type_segment_idx) {
+  TORCHTRT_CHECK(!same_type_segment_idx.empty(), "Unable to merge empty segment list");
+  if (same_type_segment_idx.size() == 1) {
+    new_partition.push_back(original_partition[same_type_segment_idx[0]]);
+  } else {
+    auto first_idx = same_type_segment_idx[0];
+    for (size_t i = 1; i < same_type_segment_idx.size(); ++i) {
+      TORCHTRT_CHECK(
+          same_type_segment_idx[i] == (first_idx + i),
+          "Unable to merge non-sequential segments: " << same_type_segment_idx);
+    }
+    LOG_DEBUG(
+        "Merging adjacent " << SegmentedBlock::target_to_str(segment_kind) << " segments: " << same_type_segment_idx);
+    std::vector<torch::jit::Node*> nodes;
+    for (auto segment_to_merge : same_type_segment_idx) {
+      const auto& merge_nodes = original_partition[segment_to_merge].raw_nodes();
+      nodes.insert(nodes.end(), merge_nodes.begin(), merge_nodes.end());
+    }
+    new_partition.emplace_back(segment_kind, nodes);
+  }
+}
+
+PartitionedGraph merge_adjacent_segments_of_same_type(PartitionedGraph& original_partition) {
+  PartitionedGraph new_partition;
+  SegmentedBlock::SegmentedBlockTarget segment_kind = SegmentedBlock::SegmentedBlockTarget::kTorch;
+  std::vector<size_t> same_type_segment_idx;
+  for (size_t i = 0UL; i < original_partition.size(); ++i) {
+    auto& segment = original_partition[i];
+    if (same_type_segment_idx.empty()) {
+      segment_kind = segment.target();
+    } else if (segment_kind != segment.target() || segment.do_not_merge()) {
+      merge_adjacent_segments_list_in_new_partition(
+          original_partition, new_partition, segment_kind, same_type_segment_idx);
+      same_type_segment_idx.clear();
+      segment_kind = segment.target();
+    }
+    if (segment.do_not_merge()) {
+      new_partition.push_back(segment);
+    } else {
+      same_type_segment_idx.push_back(i);
+    }
+  }
+  if (!same_type_segment_idx.empty()) {
+    merge_adjacent_segments_list_in_new_partition(
+        original_partition, new_partition, segment_kind, same_type_segment_idx);
+  }
+  return new_partition;
+}
+
 void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   // Find all the fallback nodes and build execution decision LUT for all nodes
   setNodeExecutorLUT(ctx, block);
```
```diff
@@ -365,58 +447,75 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   PartitionedGraph segmented_blocks;
 
   std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
+  std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
+  std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
   for (const auto n : nodes) {
     // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
+    auto dependent_nodes = getDependentNodes(n);
     // the outputs of trt subgraph shouldn't be collections
     if (ctx->shouldNodeRunInTensorRT(n)) {
       in_prog_trt_blk_nodes.push_back(n);
+      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
 
-      // If there is an active PyTorch block and we have passed the threshold for a valid TRT
-      // block then segment and reset the active PyTorch block
-      if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size && !in_prog_pyt_blk_nodes.empty()) {
+      // If we hit a TRT node that is dependent on nodes in the active PyTorch block, finalize the block to materialize
+      // those dependencies in the graph
+      if (cur_pyt_nodes_uses.count(n)) {
         finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+        cur_pyt_nodes_uses.clear();
       }
     } else {
-      // If there is an active TRT block that is valid segment and reset the active TRT block
-      // otherwise add it to the active PyTorch block and reset
-      if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
-      } else {
-        LOG_DEBUG(
-            "In progress TRT block does not meet minimum block size requirements ("
-            << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
-            << "), therefore folding into in progress PyTorch block");
-        in_prog_pyt_blk_nodes.insert(
-            in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+      // The current node is dependent on the active TRT block, finalize it to materialize those dependencies in the
+      // graph or add them to the active PyTorch block
+      if (cur_trt_nodes_uses.count(n)) {
+        // If there is an active TRT block that is valid segment and reset the active TRT block
+        // otherwise add it to the active PyTorch block and reset
+        if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
+          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+        } else {
+          LOG_DEBUG(
+              "In progress TRT block does not meet minimum block size requirements ("
+              << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
+              << "), therefore folding into in progress PyTorch block");
+          in_prog_pyt_blk_nodes.insert(
+              in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+          cur_pyt_nodes_uses.insert(cur_trt_nodes_uses.begin(), cur_trt_nodes_uses.end());
+        }
+        in_prog_trt_blk_nodes.clear();
+        cur_trt_nodes_uses.clear();
       }
-      in_prog_trt_blk_nodes.clear();
       // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
       // we shouldn't inject node for this block in dependency analysis process
       if (n->kind() == torch::jit::prim::If) {
        LOG_DEBUG(
            "Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional");
        if (!in_prog_pyt_blk_nodes.empty()) {
          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          cur_pyt_nodes_uses.clear();
        }
        auto cond_node = std::vector<torch::jit::Node*>{n};
        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
+        segmented_blocks.back().do_not_merge(true);
        continue;
      } else if (n->kind() == torch::jit::prim::Loop) {
        if (!in_prog_pyt_blk_nodes.empty()) {
          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          cur_pyt_nodes_uses.clear();
        }
        if (checkLoopEvaluatable(n)) {
          in_prog_trt_blk_nodes.push_back(n);
+          cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
        } else {
          auto loop_node = std::vector<torch::jit::Node*>{n};
          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
+          segmented_blocks.back().do_not_merge(true);
        }
        continue;
      }
      in_prog_pyt_blk_nodes.push_back(n);
+      cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
    }
  }
 
```
```diff
@@ -432,6 +531,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
     finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
 
+  segmented_blocks = merge_adjacent_segments_of_same_type(segmented_blocks);
   ctx->partitioned_blocks.insert({block, segmented_blocks});
   return;
 }
```
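For context, the merge pass introduced above can be sketched in isolation. The following is a hypothetical, self-contained approximation (a stand-in `Seg` type rather than the real `SegmentedBlock`/`PartitionedGraph` API) of how runs of adjacent same-type segments collapse while `do_not_merge` segments are preserved:

```cpp
// Hypothetical stand-in types for illustration; mirrors the behavior of the merge pass above.
#include <cassert>
#include <vector>

enum class Kind { kTorch, kTensorRT };

struct Seg {
  Kind kind;
  bool do_not_merge = false;
  int node_count = 1;
};

std::vector<Seg> merge_adjacent(const std::vector<Seg>& in) {
  std::vector<Seg> out;
  for (const auto& s : in) {
    // Conditionals/non-evaluatable loops are finalized as-is and never merged.
    if (s.do_not_merge) {
      out.push_back(s);
      continue;
    }
    // Extend the previous segment when it has the same target and is mergeable.
    if (!out.empty() && !out.back().do_not_merge && out.back().kind == s.kind) {
      out.back().node_count += s.node_count;
    } else {
      out.push_back(s);
    }
  }
  return out;
}

int main() {
  // Torch, Torch, TRT, TRT, If(do_not_merge), TRT  ->  Torch, TRT, If, TRT
  std::vector<Seg> segs = {
      {Kind::kTorch, false}, {Kind::kTorch, false}, {Kind::kTensorRT, false},
      {Kind::kTensorRT, false}, {Kind::kTorch, true}, {Kind::kTensorRT, false}};
  auto merged = merge_adjacent(segs);
  assert(merged.size() == 4);
  return 0;
}
```

Here the non-mergeable stand-in for a prim::If block keeps the TRT segments on either side of it from being merged into one.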

core/partitioning/segmentedblock/SegmentedBlock.h

Lines changed: 9 additions & 0 deletions
```diff
@@ -94,6 +94,14 @@ struct SegmentedBlock {
     return target_;
   }
 
+  bool do_not_merge(void) const {
+    return do_not_merge_;
+  }
+
+  void do_not_merge(bool x) {
+    do_not_merge_ = x;
+  }
+
   friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b);
 
  private:
```
```diff
@@ -106,6 +114,7 @@ struct SegmentedBlock {
   std::vector<torch::jit::Node*> nodes_;
   std::shared_ptr<torch::jit::Graph> g_;
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
+  bool do_not_merge_ = false;
 };
 
 std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
```
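A small, self-contained illustration of the same overloaded getter/setter pattern (hypothetical `Block` type, not the real header), showing the default value and how the flag is toggled:

```cpp
// Sketch mirroring the do_not_merge accessors added above (hypothetical type).
#include <cassert>

struct Block {
  bool do_not_merge() const {
    return do_not_merge_;
  }
  void do_not_merge(bool x) {
    do_not_merge_ = x;
  }

 private:
  bool do_not_merge_ = false;
};

int main() {
  Block b;
  assert(!b.do_not_merge());  // default: mergeable
  b.do_not_merge(true);       // mark as non-mergeable (e.g. a conditional block)
  assert(b.do_not_merge());
  return 0;
}
```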

tests/core/partitioning/test_conditionals.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,7 +40,7 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
 
   auto conditional_engines_count = count_trt_engines_in_conditionals(new_g);
 
-  ASSERT_TRUE(conditional_engines_count == 2);
+  ASSERT_TRUE(conditional_engines_count == 1);
 }
 
 TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
```

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -204,7 +204,7 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
       }));
     }
   }
-  ASSERT_TRUE(trt_block_cnt == 2 && torch_block_cnt == 2);
+  ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 1);
 }
 
 TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
```
