
fix: Issue in non-Tensor Input Resolution #1617

Merged
merged 5 commits on Feb 22, 2023

Changes from all commits

51 changes: 20 additions & 31 deletions core/partitioning/partitioning.cpp
@@ -47,6 +47,21 @@ void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block)
}
}

// Need to check if this makes sense; it might be a root cause of some over-aggressive fallback issues
bool checkLoopEvaluatable(torch::jit::Node* n) {
bool compile_to_trt = true;
for (auto bn : n->blocks()[0]->nodes()) {
if (bn->kind() == torch::jit::prim::Loop) {
compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
} else if (bn->kind() == torch::jit::prim::If) {
compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
} else {
compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
}
}
return compile_to_trt;
}

// Find and set all explicit fallback nodes (nodes that are unsupported or forced to fall back).
// We use a map to record the reason why each node falls back to Torch.
// Any node that is not an explicit fallback is set to run in TensorRT for now.
@@ -59,7 +74,9 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
continue;
}

if (!conversion::OpSupported(n)) {
if (n->kind() == torch::jit::prim::Loop && checkLoopEvaluatable(n)) {
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
} else if (!conversion::OpSupported(n)) {
// If the op is not supported by the conversion phase it should run in PyTorch
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
} else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
@@ -269,7 +286,8 @@ void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
dependency_nodes.end(),
cur_partitioned_block[i].raw_nodes().begin(),
cur_partitioned_block[i].raw_nodes().end());
cur_partitioned_block[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes);
cur_partitioned_block[i] =
SegmentedBlock(cur_partitioned_block[i].get_id(), SegmentedBlock::kTensorRT, dependency_nodes);
Comment on lines +289 to +290
Collaborator Author

Added to ensure the block IDs of segmented blocks stay in order despite resolution of non-Tensor inputs.
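
A minimal sanity check one could add after the resolution pass (hypothetical, not part of this PR) to confirm the ordering guarantee described above, assuming BlockID is an ordered integral type and <cassert> is included:

// Hypothetical assertion: after non-Tensor input resolution, the rebuilt
// TensorRT blocks should still carry strictly increasing block IDs.
for (size_t i = 1; i < cur_partitioned_block.size(); ++i) {
  assert(cur_partitioned_block[i - 1].get_id() < cur_partitioned_block[i].get_id());
}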

}
}
}
@@ -336,21 +354,6 @@ void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
return;
}

// Need to check if this makes sense might be a root cause of some issues of over aggressive fallback
bool checkLoopEvaluatable(torch::jit::Node* n) {
bool compile_to_trt = true;
for (auto bn : n->blocks()[0]->nodes()) {
if (bn->kind() == torch::jit::prim::Loop) {
compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
} else if (bn->kind() == torch::jit::prim::If) {
compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
} else {
compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
}
}
return compile_to_trt;
}

void finalizeNewBlock(
PartitionedGraph& g,
SegmentedBlock::SegmentedBlockTarget kind,
@@ -499,20 +502,6 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
segmented_blocks.back().do_not_merge(true);
continue;
} else if (n->kind() == torch::jit::prim::Loop) {
if (!in_prog_pyt_blk_nodes.empty()) {
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
cur_pyt_nodes_uses.clear();
}
if (checkLoopEvaluatable(n)) {
in_prog_trt_blk_nodes.push_back(n);
cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
} else {
auto loop_node = std::vector<torch::jit::Node*>{n};
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
segmented_blocks.back().do_not_merge(true);
}
continue;
}
in_prog_pyt_blk_nodes.push_back(n);
cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
3 changes: 3 additions & 0 deletions core/partitioning/segmentedblock/SegmentedBlock.h
@@ -98,6 +98,9 @@ struct SegmentedBlock {
return in_types_;
}

BlockID get_id() {
return id_;
}
void update_id(BlockID new_id) {
id_ = new_id;
}
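
A short usage sketch (hypothetical names, mirroring how resolveTRTNonTensorInputs uses this accessor in partitioning.cpp above): read a block's ID before rebuilding it so the replacement keeps its position in the partition ordering.

// Rebuild a segmented block with an expanded node list while preserving its ID.
// `block` and `expanded_nodes` are illustrative placeholders, not code from this PR.
auto original_id = block.get_id();
block = SegmentedBlock(original_id, SegmentedBlock::kTensorRT, expanded_nodes);
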
91 changes: 91 additions & 0 deletions tests/core/partitioning/test_resolve_nontensor_inputs.cpp
@@ -150,6 +150,97 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 1);
}

TEST(Partitioning, ResolveMultipleNonTensorInputsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
# TensorRT-intended Block
%16 : int = prim::Constant[value=8]()
%15 : int = prim::Constant[value=64]()
%13 : int = prim::Constant[value=0]()
%10 : int = prim::Constant[value=1]()
%self.linear.bias : Float(4096, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%self.linear.weight : Float(4096, 64, strides=[64, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%3 : int = prim::Constant[value=-1]()
%2 : int = prim::Constant[value=1]()
%x.5 : Tensor = aten::flatten(%x.1, %2, %3)
%4 : Tensor = aten::t(%self.linear.weight)
%6 : Tensor = aten::matmul(%x.5, %4)
%7 : Tensor = trt::const(%self.linear.bias)
%9 : Tensor = aten::add(%7, %6, %10)
%11 : int[] = aten::size(%9) # <string>:13:9
%12 : int = aten::__getitem__(%11, %13)
%shape.3 : int[] = prim::ListConstruct(%12, %15, %16, %16)
%x.13 : Tensor = aten::reshape(%9, %shape.3)

# Torch-intended Block
%num_spatial_dims.2 : int = prim::Constant[value=2]()
%11 : int[] = prim::Constant[value=[0, 0]]()
%10 : bool = prim::Constant[value=0]()
%conv1_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%conv1_weight : Float(32, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%6 : int = prim::Constant[value=1]()
%5 : int[] = prim::Constant[value=[1, 1]]()
%4 : int[] = prim::Constant[value=[2, 2]]()
%conv_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%conv_weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
%input.16 : Tensor = aten::conv_transpose2d(%x.13, %conv_weight, %conv_bias, %4, %5, %5, %6, %5)
%7 : Tensor = aten::_convolution(%input.16, %conv1_weight, %conv1_bias, %5, %5, %5, %10, %11, %6, %10, %10, %10, %10)
%12 : int[] = aten::size(%7)
%96 : int = aten::len(%12)
%14 : int = aten::__range_length(%num_spatial_dims.2, %96, %6)

# TensorRT-intended Block
%15 : float = prim::Constant[value=1e-05]()
%14 : float = prim::Constant[value=0.1]()
%13 : NoneType = prim::Constant()
%num_spatial_dims.2 : int = prim::Constant[value=2]()
%300 : int = prim::Constant[value=3]()
%345 : int = aten::sub(%300, %96)
%3 : int = aten::add(%345, %6)
%2 : bool = prim::Constant[value=1]()
%size_prods.2 : int = prim::Loop(%3, %2, %6)
block0(%loop : int, %size_prods.13 : int):
%i.3 : int = aten::__derive_index(%loop, %num_spatial_dims.2, %3)
%8 : int = aten::__getitem__(%12, %i.3)
%size_prods.15 : int = aten::mul(%size_prods.13, %8)
-> (%2, %size_prods.15)
%11 : Tensor = aten::instance_norm(%7, %13, %13, %13, %13, %2, %14, %15, %2)
return (%11))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get(), true);

torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
partitioning_info.enabled = true;
std::vector<torch_tensorrt::core::ir::Input> inputs;
inputs.push_back(torch_tensorrt::core::ir::Input({1, 64}));

torch_tensorrt::core::ir::CollectionInputSpecMap inputs_map;
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
for (size_t i = 0; i < g->inputs().size(); ++i) {
inputs_map.insert({g->inputs()[i], {inputs[i]}});
input_types.insert({g->inputs()[i], {{at::kFloat}}});
}

partitioning_info.collection_input_spec_map = inputs_map;
partitioning_info.forced_fallback_operators = {"aten::_convolution"};
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
ctx.input_types_map = input_types;

torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
torch_tensorrt::core::partitioning::partition(&ctx);
std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
ctx.partitioned_blocks.begin()->second;

// For each TensorRT segmented block, verify that all inputs are of Tensor type
for (auto block : segmented_blocks) {
if (block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::SegmentedBlockTarget::kTensorRT) {
for (auto input : block.raw_inputs())
ASSERT_TRUE(input->type()->isSubtypeOf(c10::TensorType::get()));
}
}
}

TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
const auto graph = R"IR(
graph(%0 : Float(1, 3, 16, 16, strides=[768, 256, 16, 1]),