@@ -29,7 +29,7 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
 }
 
 // Check if the inputs and outputs of the graph are Tensor. If not, then fallback connected nodes
-void SetInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
+void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
   // fallback nodes that produce entire graph's nonTensor output
   for (auto i : block->outputs()) {
     if (!isTensor(i)) {
@@ -50,7 +50,7 @@ void SetInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* blo
 // Find and set all explicit fallback nodes (nodes that are unsupported or forced fallback)
 // we use a map to indicate the reason why it's fallback to torch
 // For any node that's not explicitly fallback, we set it to run in TensorRT for now
-void SetExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
+void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
   auto nodes = block->nodes();
   const auto to_compile_sym = c10::Symbol::attr("to_compile");
 
@@ -78,7 +78,7 @@ void SetExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
 
 // For a given set of fallback nodes, check their inputs/outputs, if any inputs/outputs of them are NonTensor,
 // then the nodes that produces/consumes those values should also fallback
-void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::Node*>& initial_fallback_nodes) {
+void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::Node*>& initial_fallback_nodes) {
   // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
   std::queue<torch::jit::Node*> q;
   for (auto& node : initial_fallback_nodes) {
@@ -112,7 +112,7 @@ void SetNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
 }
 
 // Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size
-std::vector<torch::jit::Node*> TraverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
+std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
   auto nodes = block->nodes();
   std::vector<torch::jit::Node*> cur_trt_nodes;
   std::vector<torch::jit::Node*> min_block_fallback_nodes;
@@ -138,19 +138,19 @@ std::vector<torch::jit::Node*> TraverseNodesForMinBlockSize(PartitioningCtx* ctx
 }
 
 // Set the nodes that fallback because of min_block_size
-void SetMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
+void setMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
   // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
-  auto min_block_fallback_nodes = TraverseNodesForMinBlockSize(ctx, block);
+  auto min_block_fallback_nodes = traverseNodesForMinBlockSize(ctx, block);
 
   // keep fallback until all segments meet the min_block_size requirement
   while (!min_block_fallback_nodes.empty()) {
     for (const auto i : min_block_fallback_nodes) {
       ctx->setNodeExecutorDecision(i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK);
     }
     // find the fallback nodes because of dependency with min_block_size caused fallback nodes
-    SetNonTensorConnectedNodes(ctx, min_block_fallback_nodes);
+    setNonTensorConnectedNodes(ctx, min_block_fallback_nodes);
     // keep traverse the graph until there is no node fallback because of min_block_size
-    min_block_fallback_nodes = TraverseNodesForMinBlockSize(ctx, block);
+    min_block_fallback_nodes = traverseNodesForMinBlockSize(ctx, block);
   }
 }
 
@@ -173,7 +173,7 @@ bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
   return false;
 }
 
-std::vector<torch::jit::Node*> FindModifyingNodes(
+std::vector<torch::jit::Node*> findModifyingNodes(
     torch::jit::Value* val,
     const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
   std::vector<torch::jit::Node*> modifying_nodes;
@@ -190,7 +190,7 @@ std::vector<torch::jit::Node*> FindModifyingNodes(
 }
 
 // this function is only used when a TRT segment produces nonTensor values which are used by later TRT segment
-std::vector<torch::jit::Node*> GetDependencyNodes(
+std::vector<torch::jit::Node*> getDependencyNodes(
     const std::vector<torch::jit::Value*>& vals,
     const SegmentedBlock& seg_block) {
   // get all nodes in the segmentedblock
@@ -206,7 +206,7 @@ std::vector<torch::jit::Node*> GetDependencyNodes(
     auto node = cur_val->node();
     if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
       visited.insert(node);
-      auto modifying_nodes = FindModifyingNodes(cur_val, seg_block_nodes);
+      auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
       stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
       stk.push_back(node);
       for (auto input : node->inputs()) {
@@ -220,7 +220,7 @@ std::vector<torch::jit::Node*> GetDependencyNodes(
   return stk;
 }
 
-void ResolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
+void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   // if a TRT segment has nonTensor Inputs, the nodes that produce this nonTensor Inputs must in another TensorRT engine
   // because we have already found the interface between Torch and TRT in segmentation phase
   // what we do here is just find the dependency nodes of the TRT segments that have nonTensor inputs
@@ -235,7 +235,7 @@ void ResolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
     }
     if (!inputs_to_resolve.empty()) {
      std::vector<torch::jit::Node*> dependency_nodes =
-          GetDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
+          getDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
      dependency_nodes.insert(
          dependency_nodes.end(),
          cur_partitioned_block[i].raw_nodes().begin(),
@@ -246,7 +246,7 @@ void ResolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   }
 }
 
-void RegisterSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
+void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
   PartitionedGraph& cur_partitioned_block = ctx->partitioned_blocks[block];
   auto cmp = [](torch::jit::Value* a, torch::jit::Value* b) { return a->unique() < b->unique(); };
@@ -332,32 +332,32 @@ void finalizeNewBlock(
   LOG_DEBUG(g.back());
 }
 
-void SetNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
+void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
   // First, find all the explicit fallback nodes that should run in Torch:
   // 1. nodes that are unsupported
   // 2. nodes that the user specifies to run in torch
   // 3. nodes that the user specifies the module containing this op to run in torch
   // At the same time, set all the rest nodes to NodeExecutorDecision::kCONVERT
-  SetExplicitFallbackNodes(ctx, block);
+  setExplicitFallbackNodes(ctx, block);
 
   // Second, check if there is nonTensor input/output for the block, if there is, then fallback the nodes that
   // consume/produce this nonTensor value
-  SetInputsOutputsConnectedNodes(ctx, block);
+  setInputsOutputsConnectedNodes(ctx, block);
 
   // Third, for fallback nodes, if it consumes any NonTensor inputs, then the nodes that produce this
   // input should also fallback. Similarly, if it produces any NonTensor outputs, then the nodes
   // that consume this output should also fallback
   auto cur_fallback_nodes = ctx->getNodesRunInTorch();
-  SetNonTensorConnectedNodes(ctx, cur_fallback_nodes);
+  setNonTensorConnectedNodes(ctx, cur_fallback_nodes);
 
   // Finally, check if all current tensorrt blocks satisfy the min_block_size requirement.
   // We need to traverse the whole graph many times here
-  SetMinBlockFallbackNodes(ctx, block);
+  setMinBlockFallbackNodes(ctx, block);
 }
 
-void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
+void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   // Find all the fallback nodes and build execution decision LUT for all nodes
-  SetNodeExecutorLUT(ctx, block);
+  setNodeExecutorLUT(ctx, block);
 
   auto nodes = block->nodes();
 
@@ -436,24 +436,24 @@ void SegmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   return;
 }
 
-void Partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
+void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
   LOG_DEBUG(ctx->settings);
 
   // Go through all the blocks to do the partitioning
   for (torch::jit::Block* block : ctx->original_blocks) {
     // segment lowering global graph into blocks
-    SegmentGraph(ctx, block);
+    segmentGraph(ctx, block);
 
     // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
     // resolve nonTensor inputs/outputs
-    ResolveTRTNonTensorInputs(ctx, block);
+    resolveTRTNonTensorInputs(ctx, block);
 
     // register input/output torch::jit::Value for segmented graphs
     LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
-    RegisterSegmentsOutputs(ctx, block);
+    registerSegmentsOutputs(ctx, block);
 
     // run shape analysis on each segmented block
-    RunShapeAnalysis(ctx, block, example_tensor_map);
+    runShapeAnalysis(ctx, block, example_tensor_map);
   }
 }
 
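For orientation, a minimal sketch of how the renamed entry point might be invoked, using only names that appear in this diff; the wrapper runPartitioning and the way ctx and example_tensor_map are constructed are assumptions for illustration, not part of this change:

// Hypothetical call site (sketch). Construction of ctx and example_tensor_map
// is elided; it lives elsewhere in the partitioning module.
void runPartitioning(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
  // partition() segments every block in ctx->original_blocks, resolves
  // nonTensor inputs/outputs, registers segment outputs, and runs shape
  // analysis; the segmented graphs end up in ctx->partitioned_blocks.
  partition(ctx, example_tensor_map);
}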