diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 4d74461454..1172ab4275 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -47,6 +47,21 @@ void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* blo
   }
 }
 
+// Need to check if this makes sense might be a root cause of some issues of over aggressive fallback
+bool checkLoopEvaluatable(torch::jit::Node* n) {
+  bool compile_to_trt = true;
+  for (auto bn : n->blocks()[0]->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
+    } else if (bn->kind() == torch::jit::prim::If) {
+      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
+    } else {
+      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
+    }
+  }
+  return compile_to_trt;
+}
+
 // Find and set all explicit fallback nodes (nodes that are unsupported or forced fallback)
 // we use a map to indicate the reason why it's fallback to torch
 // For any node that's not explicitly fallback, we set it to run in TensorRT for now
@@ -59,7 +74,9 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
       continue;
     }
 
-    if (!conversion::OpSupported(n)) {
+    if (n->kind() == torch::jit::prim::Loop && checkLoopEvaluatable(n)) {
+      ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
+    } else if (!conversion::OpSupported(n)) {
       // If the op is not supported by the conversion phase it should run in PyTorch
       ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
     } else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
@@ -269,7 +286,8 @@ void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
           dependency_nodes.end(),
           cur_partitioned_block[i].raw_nodes().begin(),
           cur_partitioned_block[i].raw_nodes().end());
-      cur_partitioned_block[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes);
+      cur_partitioned_block[i] =
+          SegmentedBlock(cur_partitioned_block[i].get_id(), SegmentedBlock::kTensorRT, dependency_nodes);
     }
   }
 }
@@ -336,21 +354,6 @@ void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   return;
 }
 
-// Need to check if this makes sense might be a root cause of some issues of over aggressive fallback
-bool checkLoopEvaluatable(torch::jit::Node* n) {
-  bool compile_to_trt = true;
-  for (auto bn : n->blocks()[0]->nodes()) {
-    if (bn->kind() == torch::jit::prim::Loop) {
-      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
-    } else if (bn->kind() == torch::jit::prim::If) {
-      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
-    } else {
-      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
-    }
-  }
-  return compile_to_trt;
-}
-
 void finalizeNewBlock(
     PartitionedGraph& g,
     SegmentedBlock::SegmentedBlockTarget kind,
@@ -499,20 +502,6 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
       finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
       segmented_blocks.back().do_not_merge(true);
       continue;
-    } else if (n->kind() == torch::jit::prim::Loop) {
-      if (!in_prog_pyt_blk_nodes.empty()) {
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
-        cur_pyt_nodes_uses.clear();
-      }
-      if (checkLoopEvaluatable(n)) {
-        in_prog_trt_blk_nodes.push_back(n);
-        cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
-      } else {
-        auto loop_node = std::vector<torch::jit::Node*>{n};
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
-        segmented_blocks.back().do_not_merge(true);
-      }
-      continue;
     }
     in_prog_pyt_blk_nodes.push_back(n);
     cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
diff --git a/core/partitioning/segmentedblock/SegmentedBlock.h b/core/partitioning/segmentedblock/SegmentedBlock.h
index db5e8fedd9..66b07d0f90 100644
--- a/core/partitioning/segmentedblock/SegmentedBlock.h
+++ b/core/partitioning/segmentedblock/SegmentedBlock.h
@@ -98,6 +98,9 @@ struct SegmentedBlock {
     return in_types_;
   }
 
+  BlockID get_id() {
+    return id_;
+  }
   void update_id(BlockID new_id) {
     id_ = new_id;
   }
diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
index 5f590fa5ab..3390cab98c 100644
--- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
+++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
@@ -150,6 +150,97 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
   ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 1);
 }
 
+TEST(Partitioning, ResolveMultipleNonTensorInputsCorrectly) {
+  const auto graph = R"IR(
+        graph(%x.1 : Tensor):
+            # TensorRT-intended Block
+            %16 : int = prim::Constant[value=8]()
+            %15 : int = prim::Constant[value=64]()
+            %13 : int = prim::Constant[value=0]()
+            %10 : int = prim::Constant[value=1]()
+            %self.linear.bias : Float(4096, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+            %self.linear.weight : Float(4096, 64, strides=[64, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+            %3 : int = prim::Constant[value=-1]()
+            %2 : int = prim::Constant[value=1]()
+            %x.5 : Tensor = aten::flatten(%x.1, %2, %3)
+            %4 : Tensor = aten::t(%self.linear.weight)
+            %6 : Tensor = aten::matmul(%x.5, %4)
+            %7 : Tensor = trt::const(%self.linear.bias)
+            %9 : Tensor = aten::add(%7, %6, %10)
+            %11 : int[] = aten::size(%9) # :13:9
+            %12 : int = aten::__getitem__(%11, %13)
+            %shape.3 : int[] = prim::ListConstruct(%12, %15, %16, %16)
+            %x.13 : Tensor = aten::reshape(%9, %shape.3)
+
+            # Torch-intended Block
+            %num_spatial_dims.2 : int = prim::Constant[value=2]()
+            %11 : int[] = prim::Constant[value=[0, 0]]()
+            %10 : bool = prim::Constant[value=0]()
+            %conv1_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+            %conv1_weight : Float(32, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+            %6 : int = prim::Constant[value=1]()
+            %5 : int[] = prim::Constant[value=[1, 1]]()
+            %4 : int[] = prim::Constant[value=[2, 2]]()
+            %conv_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+            %conv_weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+            %input.16 : Tensor = aten::conv_transpose2d(%x.13, %conv_weight, %conv_bias, %4, %5, %5, %6, %5)
+            %7 : Tensor = aten::_convolution(%input.16, %conv1_weight, %conv1_bias, %5, %5, %5, %10, %11, %6, %10, %10, %10, %10)
+            %12 : int[] = aten::size(%7)
+            %96 : int = aten::len(%12)
+            %14 : int = aten::__range_length(%num_spatial_dims.2, %96, %6)
+
+            # TensorRT-intended Block
+            %15 : float = prim::Constant[value=1e-05]()
+            %14 : float = prim::Constant[value=0.1]()
+            %13 : NoneType = prim::Constant()
+            %num_spatial_dims.2 : int = prim::Constant[value=2]()
+            %300 : int = prim::Constant[value=3]()
+            %345 : int = aten::sub(%300, %96)
+            %3 : int = aten::add(%345, %6)
+            %2 : bool = prim::Constant[value=1]()
+            %size_prods.2 : int = prim::Loop(%3, %2, %6)
+              block0(%loop : int, %size_prods.13 : int):
+                %i.3 : int = aten::__derive_index(%loop, %num_spatial_dims.2, %3)
+                %8 : int = aten::__getitem__(%12, %i.3)
+                %size_prods.15 : int = aten::mul(%size_prods.13, %8)
+                -> (%2, %size_prods.15)
+            %11 : Tensor = aten::instance_norm(%7, %13, %13, %13, %13, %2, %14, %15, %2)
+            return (%11))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({1, 64}));
+
+  torch_tensorrt::core::ir::CollectionInputSpecMap inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  for (size_t i = 0; i < g->inputs().size(); ++i) {
+    inputs_map.insert({g->inputs()[i], {inputs[i]}});
+    input_types.insert({g->inputs()[i], {{at::kFloat}}});
+  }
+
+  partitioning_info.collection_input_spec_map = inputs_map;
+  partitioning_info.forced_fallback_operators = {"aten::_convolution"};
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  ctx.input_types_map = input_types;
+
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
+  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
+      ctx.partitioned_blocks.begin()->second;
+
+  // For each TensorRT segmented block, verify that all inputs are of Tensor type
+  for (auto block : segmented_blocks) {
+    if (block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::SegmentedBlockTarget::kTensorRT) {
+      for (auto input : block.raw_inputs())
+        ASSERT_TRUE(input->type()->isSubtypeOf(c10::TensorType::get()));
+    }
+  }
+}
+
 TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
   const auto graph = R"IR(
         graph(%0 : Float(1, 3, 16, 16, strides=[768, 256, 16, 1]),