Skip to content

Commit d053b4d

Browse files
committed
chore: address review comments
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 6db9ffb commit d053b4d

File tree

8 files changed

+67
-67
lines changed

8 files changed

+67
-67
lines changed

core/compiler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ partitioning::GraphAndMapping BuildHybridGraph(
138138

139139
auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
140140
auto collection_input_ivalues_map =
141-
partitioning::GenerateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
141+
partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
142142

143-
partitioning::Partition(&partitioning_ctx, collection_input_ivalues_map);
143+
partitioning::partition(&partitioning_ctx, collection_input_ivalues_map);
144144

145145
for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
146146
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
@@ -174,7 +174,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
174174
}
175175
}
176176

177-
return partitioning::Stitch(&partitioning_ctx, block);
177+
return partitioning::stitch(&partitioning_ctx, block);
178178
}
179179

180180
void MapInputsAndDetermineDTypes(

core/partitioning/partitioning.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
2929
}
3030

3131
// Check if the inputs and outputs of the graph are Tensor. If not, then fallback connected nodes
32-
void SetInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
32+
void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
3333
// fallback nodes that produce entire graph's nonTensor output
3434
for (auto i : block->outputs()) {
3535
if (!isTensor(i)) {
@@ -50,7 +50,7 @@ void SetInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* blo
5050
// Find and set all explicit fallback nodes (nodes that are unsupported or forced fallback)
5151
// we use a map to indicate the reason why it's fallback to torch
5252
// For any node that's not explicitly fallback, we set it to run in TensorRT for now
53-
void SetExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
53+
void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
5454
auto nodes = block->nodes();
5555
const auto to_compile_sym = c10::Symbol::attr("to_compile");
5656

@@ -78,7 +78,7 @@ void SetExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
7878

7979
// For a given set of fallback nodes, check their inputs/outputs, if any inputs/outputs of them are NonTensor,
8080
// then the nodes that produces/consumes those values should also fallback
81-
void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::Node*>& initial_fallback_nodes) {
81+
void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::Node*>& initial_fallback_nodes) {
8282
// initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
8383
std::queue<torch::jit::Node*> q;
8484
for (auto& node : initial_fallback_nodes) {
@@ -112,7 +112,7 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
112112
}
113113

114114
// Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size
115-
std::vector<torch::jit::Node*> TraverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
115+
std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
116116
auto nodes = block->nodes();
117117
std::vector<torch::jit::Node*> cur_trt_nodes;
118118
std::vector<torch::jit::Node*> min_block_fallback_nodes;
@@ -138,19 +138,19 @@ std::vector<torch::jit::Node*> TraverseNodesForMinBlockSize(PartitioningCtx* ctx
138138
}
139139

140140
// Set the nodes that fallback because of min_block_size
141-
void SetMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
141+
void setMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
142142
// first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
143-
auto min_block_fallback_nodes = TraverseNodesForMinBlockSize(ctx, block);
143+
auto min_block_fallback_nodes = traverseNodesForMinBlockSize(ctx, block);
144144

145145
// keep fallback until all segments meet the min_block_size requirement
146146
while (!min_block_fallback_nodes.empty()) {
147147
for (const auto i : min_block_fallback_nodes) {
148148
ctx->setNodeExecutorDecision(i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK);
149149
}
150150
// find the fallback nodes because of dependency with min_block_size caused fallback nodes
151-
SetNonTensorConnectedNodes(ctx, min_block_fallback_nodes);
151+
setNonTensorConnectedNodes(ctx, min_block_fallback_nodes);
152152
// keep traverse the graph until there is no node fallback because of min_block_size
153-
min_block_fallback_nodes = TraverseNodesForMinBlockSize(ctx, block);
153+
min_block_fallback_nodes = traverseNodesForMinBlockSize(ctx, block);
154154
}
155155
}
156156

@@ -173,7 +173,7 @@ bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
173173
return false;
174174
}
175175

176-
std::vector<torch::jit::Node*> FindModifyingNodes(
176+
std::vector<torch::jit::Node*> findModifyingNodes(
177177
torch::jit::Value* val,
178178
const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
179179
std::vector<torch::jit::Node*> modifying_nodes;
@@ -190,7 +190,7 @@ std::vector<torch::jit::Node*> FindModifyingNodes(
190190
}
191191

192192
// this function is only used when a TRT segment produces nonTensor values which are used by later TRT segment
193-
std::vector<torch::jit::Node*> GetDependencyNodes(
193+
std::vector<torch::jit::Node*> getDependencyNodes(
194194
const std::vector<torch::jit::Value*>& vals,
195195
const SegmentedBlock& seg_block) {
196196
// get all nodes in the segmentedblock
@@ -206,7 +206,7 @@ std::vector<torch::jit::Node*> GetDependencyNodes(
206206
auto node = cur_val->node();
207207
if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
208208
visited.insert(node);
209-
auto modifying_nodes = FindModifyingNodes(cur_val, seg_block_nodes);
209+
auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
210210
stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
211211
stk.push_back(node);
212212
for (auto input : node->inputs()) {
@@ -220,7 +220,7 @@ std::vector<torch::jit::Node*> GetDependencyNodes(
220220
return stk;
221221
}
222222

223-
void ResolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
223+
void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
224224
// if a TRT segment has nonTensor Inputs, the nodes that produce this nonTensor Inputs must in another TensorRT engine
225225
// because we have already found the interface between Torch and TRT in segmentation phase
226226
// what we do here is just find the dependency nodes of the TRT segments that have nonTensor inputs
@@ -235,7 +235,7 @@ void ResolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
235235
}
236236
if (!inputs_to_resolve.empty()) {
237237
std::vector<torch::jit::Node*> dependency_nodes =
238-
GetDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
238+
getDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
239239
dependency_nodes.insert(
240240
dependency_nodes.end(),
241241
cur_partitioned_block[i].raw_nodes().begin(),
@@ -246,7 +246,7 @@ void ResolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
246246
}
247247
}
248248

249-
void RegisterSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
249+
void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
250250
// find the corresponding raw values in original global graph for this segmented block's inputs/outputs
251251
PartitionedGraph& cur_partitioned_block = ctx->partitioned_blocks[block];
252252
auto cmp = [](torch::jit::Value* a, torch::jit::Value* b) { return a->unique() < b->unique(); };
@@ -332,32 +332,32 @@ void finalizeNewBlock(
332332
LOG_DEBUG(g.back());
333333
}
334334

335-
void SetNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
335+
void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
336336
// First, find all the explicit fallback nodes that should run in Torch:
337337
// 1. nodes that are unsupported
338338
// 2. nodes that the user specifies to run in torch
339339
// 3. nodes that the user specifies the module containing this op to run in torch
340340
// At the same time, set all the rest nodes to NodeExecutorDecision::kCONVERT
341-
SetExplicitFallbackNodes(ctx, block);
341+
setExplicitFallbackNodes(ctx, block);
342342

343343
// Second, check if there is nonTensor input/output for the block, if there is, then fallback the nodes that
344344
// consume/produce this nonTensor value
345-
SetInputsOutputsConnectedNodes(ctx, block);
345+
setInputsOutputsConnectedNodes(ctx, block);
346346

347347
// Third, for fallback nodes, if it consumes any NonTensor inputs, then the nodes that produce this
348348
// input should also fallback. Similarly, if it produces any NonTensor outputs, then the nodes
349349
// that consume this output should also fallback
350350
auto cur_fallback_nodes = ctx->getNodesRunInTorch();
351-
SetNonTensorConnectedNodes(ctx, cur_fallback_nodes);
351+
setNonTensorConnectedNodes(ctx, cur_fallback_nodes);
352352

353353
// Finally, check if all current tensorrt blocks satisfy the min_block_size requirement.
354354
// We need to traverse the whole graph many times here
355-
SetMinBlockFallbackNodes(ctx, block);
355+
setMinBlockFallbackNodes(ctx, block);
356356
}
357357

358-
void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
358+
void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
359359
// Find all the fallback nodes and build execution decision LUT for all nodes
360-
SetNodeExecutorLUT(ctx, block);
360+
setNodeExecutorLUT(ctx, block);
361361

362362
auto nodes = block->nodes();
363363

@@ -436,24 +436,24 @@ void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
436436
return;
437437
}
438438

439-
void Partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
439+
void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
440440
LOG_DEBUG(ctx->settings);
441441

442442
// Go through all the blocks to do the partitioning
443443
for (torch::jit::Block* block : ctx->original_blocks) {
444444
// segment lowering global graph into blocks
445-
SegmentGraph(ctx, block);
445+
segmentGraph(ctx, block);
446446

447447
// It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
448448
// resolve nonTensor inputs/outputs
449-
ResolveTRTNonTensorInputs(ctx, block);
449+
resolveTRTNonTensorInputs(ctx, block);
450450

451451
// register input/output torch::jit::Value for segmented graphs
452452
LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
453-
RegisterSegmentsOutputs(ctx, block);
453+
registerSegmentsOutputs(ctx, block);
454454

455455
// run shape analysis on each segmented block
456-
RunShapeAnalysis(ctx, block, example_tensor_map);
456+
runShapeAnalysis(ctx, block, example_tensor_map);
457457
}
458458
}
459459

core/partitioning/partitioning.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> Example
1818
typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
1919
GraphAndMapping;
2020

21-
ExampleIValues GenerateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types);
21+
ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types);
2222

23-
void RunShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps);
23+
void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps);
2424

25-
void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
25+
void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
2626

27-
GraphAndMapping Stitch(PartitioningCtx* ctx, torch::jit::Block* block);
27+
GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);
2828

29-
void Partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map);
29+
void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map);
3030

3131
} // namespace partitioning
3232
} // namespace core

core/partitioning/shape_analysis.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace torch_tensorrt {
99
namespace core {
1010
namespace partitioning {
1111

12-
at::Tensor GenerateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt) {
12+
at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt) {
1313
auto cur_shape = input.input_shape;
1414
std::vector<int64_t> shape;
1515
shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
@@ -25,7 +25,7 @@ at::Tensor GenerateSingleInput(ir::Input& input, c10::optional<at::ScalarType>&
2525
return in;
2626
}
2727

28-
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> GenerateRandomInputs(
28+
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
2929
std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>>& inputs,
3030
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types) {
3131
// generate random inputs for running pytorch segments
@@ -38,28 +38,28 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> GenerateRandomI
3838
c10::TypePtr elementType = c10::TensorType::get();
3939
auto generic_list = c10::impl::GenericList(elementType);
4040
for (size_t i = 0; i < input.second.size(); i++) {
41-
auto in = GenerateSingleInput(input.second[i], types[input.first][i]);
41+
auto in = generateSingleInput(input.second[i], types[input.first][i]);
4242
generic_list.push_back(in.clone());
4343
}
4444
ivalue_map[input.first] = c10::IValue(generic_list);
4545
} else if (input.first->type()->kind() == torch::jit::TypeKind::TupleType) {
4646
// create tuple
4747
std::vector<torch::jit::IValue> list;
4848
for (size_t i = 0; i < input.second.size(); i++) {
49-
auto in = GenerateSingleInput(input.second[i], types[input.first][i]);
49+
auto in = generateSingleInput(input.second[i], types[input.first][i]);
5050
list.push_back(in.clone());
5151
}
5252
auto tuple = c10::ivalue::Tuple::create(list); // create tuple ptr
5353
ivalue_map[input.first] = c10::IValue(tuple);
5454
} else {
55-
auto in = GenerateSingleInput(input.second[0], types[input.first][0]);
55+
auto in = generateSingleInput(input.second[0], types[input.first][0]);
5656
ivalue_map[input.first] = in.clone();
5757
}
5858
}
5959
return ivalue_map;
6060
}
6161

62-
void GetSegmentsOutputByRunning(
62+
void getSegmentsOutputByRunning(
6363
SegmentedBlock& seg_block,
6464
std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
6565
const PartitioningInfo& partitioning_info) {
@@ -181,11 +181,11 @@ void GetSegmentsOutputByRunning(
181181
seg_block.register_intypes(input_types);
182182
}
183183

184-
void RunShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
184+
void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
185185
// register every segment's input shape, and it's running output IValues
186186
for (auto& seg_block : ctx->partitioned_blocks[block]) {
187187
torch::jit::ConstantPooling(seg_block.g());
188-
GetSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings);
188+
getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings);
189189
}
190190
return;
191191
}

core/partitioning/stitching.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace torch_tensorrt {
99
namespace core {
1010
namespace partitioning {
1111

12-
void AddSegmentedBlockToGraph(
12+
void addSegmentedBlockToGraph(
1313
std::shared_ptr<torch::jit::Graph>& g,
1414
partitioning::SegmentedBlock& seg,
1515
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
@@ -49,7 +49,7 @@ void AddSegmentedBlockToGraph(
4949
return;
5050
}
5151

52-
void AddIfBlockToGraph(
52+
void addIfBlockToGraph(
5353
std::shared_ptr<torch::jit::Graph>& new_g,
5454
torch::jit::Node* if_node,
5555
const std::vector<GraphAndMapping>& graph_and_mappings,
@@ -97,7 +97,7 @@ void AddIfBlockToGraph(
9797
return;
9898
}
9999

100-
GraphAndMapping Stitch(PartitioningCtx* ctx, torch::jit::Block* block) {
100+
GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block) {
101101
auto new_g = std::make_shared<torch::jit::Graph>();
102102

103103
// the mapping from lowering graph => fallback global graph
@@ -109,20 +109,20 @@ GraphAndMapping Stitch(PartitioningCtx* ctx, torch::jit::Block* block) {
109109
for (auto seg_block : ctx->partitioned_blocks[block]) {
110110
LOG_INFO("Block segment:" << seg_block);
111111
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
112-
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
112+
addSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
113113
} else {
114114
if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
115115
auto if_node = seg_block.raw_nodes()[0];
116116

117117
// convert the 2 blocks in prim::if and get the converted graph with mappings
118118
std::vector<GraphAndMapping> graph_and_mappings;
119119
for (auto cur_block : if_node->blocks()) {
120-
graph_and_mappings.push_back(Stitch(ctx, cur_block));
120+
graph_and_mappings.push_back(stitch(ctx, cur_block));
121121
}
122-
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
122+
addIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
123123

124124
} else {
125-
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
125+
addSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
126126
}
127127
}
128128
}

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
122122
inputs_map.insert({g->inputs()[i], {inputs[i]}});
123123
input_types.insert({g->inputs()[i], {{at::kFloat}}});
124124
}
125-
auto input_ivalues_map = torch_tensorrt::core::partitioning::GenerateRandomInputs(inputs_map, input_types);
125+
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
126126
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
127-
torch_tensorrt::core::partitioning::Partition(&ctx, input_ivalues_map);
127+
torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
128128
std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
129129
ctx.partitioned_blocks.begin()->second;
130130

@@ -182,10 +182,10 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
182182
inputs_map.insert({g->inputs()[i], {inputs[i]}});
183183
input_types.insert({g->inputs()[i], {{at::kFloat}}});
184184
}
185-
auto input_ivalues_map = torch_tensorrt::core::partitioning::GenerateRandomInputs(inputs_map, input_types);
185+
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
186186
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
187187

188-
torch_tensorrt::core::partitioning::Partition(&ctx, input_ivalues_map);
188+
torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
189189
std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
190190
ctx.partitioned_blocks.begin()->second;
191191

@@ -376,9 +376,9 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
376376
inputs_map.insert({g->inputs()[i], {inputs[i]}});
377377
input_types.insert({g->inputs()[i], {{at::kFloat}}});
378378
}
379-
auto input_ivalues_map = torch_tensorrt::core::partitioning::GenerateRandomInputs(inputs_map, input_types);
379+
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
380380
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
381-
torch_tensorrt::core::partitioning::Partition(&ctx, input_ivalues_map);
381+
torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
382382
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
383383

384384
int torch_block_cnt = 0, trt_block_cnt = 0;

0 commit comments

Comments
 (0)