
Commit 248793d

Merge pull request #1691 from pytorch/fix_loop_fallback
fix: fix the prim::Loop fallback issue
2 parents fce0a01 + 79dc360

File tree

4 files changed, +119 -34 lines

core/partitioning/partitioning.cpp

Lines changed: 20 additions & 31 deletions

@@ -47,6 +47,21 @@ void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
   }
 }
 
+// Need to check if this makes sense might be a root cause of some issues of over aggressive fallback
+bool checkLoopEvaluatable(torch::jit::Node* n) {
+  bool compile_to_trt = true;
+  for (auto bn : n->blocks()[0]->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
+    } else if (bn->kind() == torch::jit::prim::If) {
+      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
+    } else {
+      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
+    }
+  }
+  return compile_to_trt;
+}
+
 // Find and set all explicit fallback nodes (nodes that are unsupported or forced fallback)
 // we use a map to indicate the reason why it's fallback to torch
 // For any node that's not explicitly fallback, we set it to run in TensorRT for now
@@ -59,7 +74,9 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
       continue;
     }
 
-    if (!conversion::OpSupported(n)) {
+    if (n->kind() == torch::jit::prim::Loop && checkLoopEvaluatable(n)) {
+      ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
+    } else if (!conversion::OpSupported(n)) {
       // If the op is not supported by the conversion phase it should run in PyTorch
       ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
     } else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
@@ -269,7 +286,8 @@ void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
           dependency_nodes.end(),
           cur_partitioned_block[i].raw_nodes().begin(),
           cur_partitioned_block[i].raw_nodes().end());
-      cur_partitioned_block[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes);
+      cur_partitioned_block[i] =
+          SegmentedBlock(cur_partitioned_block[i].get_id(), SegmentedBlock::kTensorRT, dependency_nodes);
     }
   }
 }
@@ -336,21 +354,6 @@ void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
     return;
   }
 
-// Need to check if this makes sense might be a root cause of some issues of over aggressive fallback
-bool checkLoopEvaluatable(torch::jit::Node* n) {
-  bool compile_to_trt = true;
-  for (auto bn : n->blocks()[0]->nodes()) {
-    if (bn->kind() == torch::jit::prim::Loop) {
-      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
-    } else if (bn->kind() == torch::jit::prim::If) {
-      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
-    } else {
-      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
-    }
-  }
-  return compile_to_trt;
-}
-
 void finalizeNewBlock(
     PartitionedGraph& g,
     SegmentedBlock::SegmentedBlockTarget kind,
@@ -499,20 +502,6 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
       finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
       segmented_blocks.back().do_not_merge(true);
       continue;
-    } else if (n->kind() == torch::jit::prim::Loop) {
-      if (!in_prog_pyt_blk_nodes.empty()) {
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
-        cur_pyt_nodes_uses.clear();
-      }
-      if (checkLoopEvaluatable(n)) {
-        in_prog_trt_blk_nodes.push_back(n);
-        cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
-      } else {
-        auto loop_node = std::vector<torch::jit::Node*>{n};
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
-        segmented_blocks.back().do_not_merge(true);
-      }
-      continue;
     }
     in_prog_pyt_blk_nodes.push_back(n);
     cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
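
For intuition, below is a minimal standalone model of the recursion `checkLoopEvaluatable` performs. The `Node` struct and its two boolean flags are hypothetical stand-ins for `torch::jit::Node`, `shouldEvalAtConversionTime`, and `containNonTensorOutputs`; only the control flow mirrors the function added above.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string kind;                         // e.g. "prim::Loop", "aten::mul"
  bool evaluatable = false;                 // stand-in for shouldEvalAtConversionTime(bn)
  bool non_tensor_outputs = false;          // stand-in for containNonTensorOutputs(bn)
  std::vector<std::shared_ptr<Node>> body;  // stand-in for n->blocks()[0]->nodes()
};

// A loop qualifies only if every node in its body is evaluatable, a nested
// loop that passes the same check, or an If whose outputs are all non-Tensor.
bool checkLoopEvaluatable(const Node& n) {
  bool compile_to_trt = true;
  for (const auto& bn : n.body) {
    if (bn->kind == "prim::Loop") {
      compile_to_trt = compile_to_trt && checkLoopEvaluatable(*bn);
    } else if (bn->kind == "prim::If") {
      compile_to_trt = compile_to_trt && bn->non_tensor_outputs;
    } else {
      compile_to_trt = compile_to_trt && bn->evaluatable;
    }
  }
  return compile_to_trt;
}

int main() {
  auto body_op = std::make_shared<Node>();
  body_op->kind = "aten::mul";
  body_op->evaluatable = true;  // integer arithmetic the evaluators can run
  Node loop{"prim::Loop", false, false, {body_op}};
  std::cout << std::boolalpha << checkLoopEvaluatable(loop) << "\n";  // true
  return 0;
}
```

The net effect of this file's changes: an evaluatable `prim::Loop` is now marked `kCONVERT` once, up front, in `setExplicitFallbackNodes`, so the special-case branch deleted from `segmentGraph` above is no longer needed.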

core/partitioning/partitioningctx/PartitioningCtx.cpp

Lines changed: 5 additions & 3 deletions

@@ -15,9 +15,11 @@ PartitioningCtx::PartitioningCtx(torch::jit::Block* b, PartitioningInfo info)
 }
 
 void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) {
-  if (!b->owningNode() || b->owningNode()->kind() != torch::jit::prim::Loop) {
-    original_blocks.push_back(b);
-  }
+  if (b->owningNode() && b->owningNode()->kind() == torch::jit::prim::Loop)
+    return;
+
+  original_blocks.push_back(b);
+
   for (const auto n : b->nodes()) {
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
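
This rewrite is more than a guard-clause cleanup: the early `return` now also skips the node loop, so nodes inside a `prim::Loop` body no longer receive individual executor decisions; the enclosing Loop node is decided as a whole in `setExplicitFallbackNodes`. A toy sketch of the revised traversal follows, with hypothetical `TBlock`/`TNode` types standing in for the TorchScript IR classes (the recursion into sub-blocks is assumed for illustration, not shown in this hunk).

```cpp
#include <iostream>
#include <string>
#include <vector>

struct TBlock;

struct TNode {
  std::string kind;              // e.g. "prim::Loop", "prim::Constant"
  std::vector<TBlock*> blocks;   // nested blocks of control-flow nodes
};

struct TBlock {
  TNode* owning_node = nullptr;  // stand-in for b->owningNode()
  std::vector<TNode*> nodes;     // stand-in for b->nodes()
};

void loadNodesIntoDecisionMap(TBlock* b, std::vector<TBlock*>& original_blocks) {
  // Guard clause from the patch: a block owned by a prim::Loop is skipped
  // before anything is recorded, so neither the block nor the loop body's
  // nodes enter the decision map.
  if (b->owning_node && b->owning_node->kind == "prim::Loop")
    return;

  original_blocks.push_back(b);
  for (auto* n : b->nodes) {
    if (n->kind == "prim::Constant")
      continue;
    for (auto* nested : n->blocks)  // assumed recursion into sub-blocks
      loadNodesIntoDecisionMap(nested, original_blocks);
  }
}

int main() {
  TNode loop{"prim::Loop", {}};
  TBlock loop_body{&loop, {}};
  std::vector<TBlock*> original_blocks;
  loadNodesIntoDecisionMap(&loop_body, original_blocks);
  std::cout << original_blocks.size() << "\n";  // 0: the loop body was skipped
  return 0;
}
```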

core/partitioning/segmentedblock/SegmentedBlock.h

Lines changed: 3 additions & 0 deletions

@@ -98,6 +98,9 @@ struct SegmentedBlock {
     return in_types_;
   }
 
+  BlockID get_id() {
+    return id_;
+  }
   void update_id(BlockID new_id) {
     id_ = new_id;
   }
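
The new `get_id()` accessor complements the existing `update_id()` mutator and exists for the call site in `resolveTRTNonTensorInputs` above: when a TensorRT block is rebuilt with its dependency nodes prepended, the replacement keeps the original `BlockID`. A usage sketch mirroring that call site (the assumption being that constructing the block without an explicit ID would not preserve the old one, which this diff does not show directly):

```cpp
// Read the block's current ID and carry it into the rebuilt kTensorRT block
// so any bookkeeping keyed on BlockID stays consistent across the rebuild.
auto preserved_id = cur_partitioned_block[i].get_id();
cur_partitioned_block[i] =
    SegmentedBlock(preserved_id, SegmentedBlock::kTensorRT, dependency_nodes);
```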

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 91 additions & 0 deletions

@@ -150,6 +150,97 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
   ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 1);
 }
 
+TEST(Partitioning, ResolveMultipleNonTensorInputsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        # TensorRT-intended Block
+        %16 : int = prim::Constant[value=8]()
+        %15 : int = prim::Constant[value=64]()
+        %13 : int = prim::Constant[value=0]()
+        %10 : int = prim::Constant[value=1]()
+        %self.linear.bias : Float(4096, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+        %self.linear.weight : Float(4096, 64, strides=[64, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+        %3 : int = prim::Constant[value=-1]()
+        %2 : int = prim::Constant[value=1]()
+        %x.5 : Tensor = aten::flatten(%x.1, %2, %3)
+        %4 : Tensor = aten::t(%self.linear.weight)
+        %6 : Tensor = aten::matmul(%x.5, %4)
+        %7 : Tensor = trt::const(%self.linear.bias)
+        %9 : Tensor = aten::add(%7, %6, %10)
+        %11 : int[] = aten::size(%9) # <string>:13:9
+        %12 : int = aten::__getitem__(%11, %13)
+        %shape.3 : int[] = prim::ListConstruct(%12, %15, %16, %16)
+        %x.13 : Tensor = aten::reshape(%9, %shape.3)
+
+        # Torch-intended Block
+        %num_spatial_dims.2 : int = prim::Constant[value=2]()
+        %11 : int[] = prim::Constant[value=[0, 0]]()
+        %10 : bool = prim::Constant[value=0]()
+        %conv1_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+        %conv1_weight : Float(32, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+        %6 : int = prim::Constant[value=1]()
+        %5 : int[] = prim::Constant[value=[1, 1]]()
+        %4 : int[] = prim::Constant[value=[2, 2]]()
+        %conv_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+        %conv_weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+        %input.16 : Tensor = aten::conv_transpose2d(%x.13, %conv_weight, %conv_bias, %4, %5, %5, %6, %5)
+        %7 : Tensor = aten::_convolution(%input.16, %conv1_weight, %conv1_bias, %5, %5, %5, %10, %11, %6, %10, %10, %10, %10)
+        %12 : int[] = aten::size(%7)
+        %96 : int = aten::len(%12)
+        %14 : int = aten::__range_length(%num_spatial_dims.2, %96, %6)
+
+        # TensorRT-intended Block
+        %15 : float = prim::Constant[value=1e-05]()
+        %14 : float = prim::Constant[value=0.1]()
+        %13 : NoneType = prim::Constant()
+        %num_spatial_dims.2 : int = prim::Constant[value=2]()
+        %300 : int = prim::Constant[value=3]()
+        %345 : int = aten::sub(%300, %96)
+        %3 : int = aten::add(%345, %6)
+        %2 : bool = prim::Constant[value=1]()
+        %size_prods.2 : int = prim::Loop(%3, %2, %6)
+          block0(%loop : int, %size_prods.13 : int):
+            %i.3 : int = aten::__derive_index(%loop, %num_spatial_dims.2, %3)
+            %8 : int = aten::__getitem__(%12, %i.3)
+            %size_prods.15 : int = aten::mul(%size_prods.13, %8)
+            -> (%2, %size_prods.15)
+        %11 : Tensor = aten::instance_norm(%7, %13, %13, %13, %13, %2, %14, %15, %2)
+        return (%11))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({1, 64}));
+
+  torch_tensorrt::core::ir::CollectionInputSpecMap inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  for (size_t i = 0; i < g->inputs().size(); ++i) {
+    inputs_map.insert({g->inputs()[i], {inputs[i]}});
+    input_types.insert({g->inputs()[i], {{at::kFloat}}});
+  }
+
+  partitioning_info.collection_input_spec_map = inputs_map;
+  partitioning_info.forced_fallback_operators = {"aten::_convolution"};
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  ctx.input_types_map = input_types;
+
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
+  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
+      ctx.partitioned_blocks.begin()->second;
+
+  // For each TensorRT segmented block, verify that all inputs are of Tensor type
+  for (auto block : segmented_blocks) {
+    if (block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::SegmentedBlockTarget::kTensorRT) {
+      for (auto input : block.raw_inputs())
+        ASSERT_TRUE(input->type()->isSubtypeOf(c10::TensorType::get()));
+    }
+  }
+}
+
 TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Float(1, 3, 16, 16, strides=[768, 256, 16, 1]),
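
As a reduced, hypothetical illustration of the shape this test exercises (not part of the test suite): a `prim::Loop` whose body holds only integer ops that the evaluators can run at conversion time, sandwiched between Tensor computations, is exactly the pattern that previously forced an unnecessary Torch fallback.

```cpp
// Hypothetical minimal IR in the same R"IR(...)" style as the tests above; the
// loop body contains only evaluator-handled integer arithmetic, so with this
// fix the loop can be marked kCONVERT instead of splitting the TRT segment.
const auto reduced_graph = R"IR(
      graph(%x : Tensor):
        %zero : int = prim::Constant[value=0]()
        %one : int = prim::Constant[value=1]()
        %t : bool = prim::Constant[value=1]()
        %n : int = prim::Constant[value=4]()
        %acc : int = prim::Loop(%n, %t, %zero)
          block0(%i : int, %a : int):
            %a2 : int = aten::add(%a, %one)
            -> (%t, %a2)
        %y : Tensor = aten::add(%x, %acc, %one)
        return (%y))IR";
```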
