diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 565f58c677..28bfd0712c 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -90,14 +90,26 @@ std::vector getDependencyNodes( return stk; } -void find_nontensor_output_nodes( +// check if the input and output of the graph is Tensor after collection is enabled. If it is, then fallback related +// nodes +void fallback_graph_nontensor_in_out( torch::jit::Block* block, std::unordered_map& global_fallback_nodes) { + // fallback nodes that produce entire graph's nonTensor output for (auto i : block->outputs()) { if (!isTensor(i)) { global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR}); } } + + // fallback nodes that consume entire graph's nonTensor input + for (auto i : block->inputs()) { + if (!isTensor(i)) { + for (auto use : i->uses()) { + global_fallback_nodes.insert({use.user, FallbackNodeType::kNON_TENSOR}); + } + } + } } void find_all_fallback_nodes( @@ -202,6 +214,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo } } } + std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) { torch::jit::EliminateDeadCode(seg_block.g()); }); @@ -440,8 +453,9 @@ PartitionedGraph Partition( const PartitionInfo& partition_info, std::unordered_map& global_fallback_nodes) { LOG_DEBUG(partition_info); - // if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output - find_nontensor_output_nodes(block, global_fallback_nodes); + // if there is nonTensor input/output for the entire graph, fallback the node that consumes/produces this nonTensor + // output + fallback_graph_nontensor_in_out(block, global_fallback_nodes); // segment lowering global graph into blocks LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");