
Commit 24172f0

Merge pull request #1263 from pytorch/partitioning_ctx
Centralizing Partitioning State
2 parents 8498287 + d053b4d commit 24172f0


41 files changed (+1076, -646 lines)

core/compiler.cpp

Lines changed: 46 additions & 174 deletions
@@ -11,7 +11,6 @@
 
 #include "torch/csrc/jit/frontend/function_schema_parser.h"
 #include "torch/csrc/jit/ir/ir.h"
-#include "torch/csrc/jit/ir/ir_views.h"
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/loop_unrolling.h"
 #include "torch/csrc/jit/passes/lower_graph.h"
@@ -128,179 +127,54 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }
 
-void AddSegmentedBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& g,
-    partitioning::SegmentedBlock& seg,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  // old_to_new_g contains: original global graph value => new global graph value,
-  // mini_to_new_g: mini graph value -> new graph value
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
-  size_t input_idx = 0;
-  if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
-    if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-      auto self = g->insertInput(0, "self_1");
-      self->setType(seg.inputs()[0]->type());
-    }
-    mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
-  }
-
-  for (auto& raw_input : seg.raw_inputs()) {
-    if (old_to_new_g.count(raw_input)) {
-      mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
-    }
-  }
-
-  for (const auto n : seg.nodes()) {
-    util::cloneNode(n, g, mini_to_new_g);
-  }
-
-  // original graph value => new global graph value
-  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
-    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
-  }
-  size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
-  for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
-    if (!old_to_new_g.count(seg.raw_inputs()[i])) {
-      old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
-    }
-  }
-
-  return;
-}
-
-typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
-    GraphAndMapping;
-
-void AddIfBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& new_g,
-    torch::jit::Node* if_node,
-    const std::vector<GraphAndMapping>& graph_and_mappings,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  torch::jit::IfView if_view(if_node);
-
-  // create a new if node in new_g and add corresponding inputs
-  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
-  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
-
-  // iterate over all blocks and add them to new created prim::If
-  for (auto graph_and_mapping : graph_and_mappings) {
-    auto new_if_block = new_if->addBlock();
-    auto cur_block_graph = graph_and_mapping.first;
-    auto cur_block_mapping = graph_and_mapping.second;
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
-    for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
-      // it's mini graph's input
-      if (old_to_new_g.count(i.first)) {
-        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
-      }
-    }
-
-    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
-    new_if_block->cloneFrom(cur_block_graph->block(), env);
-    if (cur_block_graph->inputs().size() &&
-        cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        auto self = new_g->insertInput(0, "self_1");
-        self->setType(cur_block_graph->inputs()[0]->type());
-      }
-      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
-    }
-    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
-      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
-      new_if_block->eraseInput(i);
-    }
-  }
-  for (auto ov : if_view.outputs()) {
-    auto no = new_if->addOutput();
-    old_to_new_g[ov] = no;
-    no->copyMetadata(ov);
-  }
-  return;
-}
-
-GraphAndMapping ConstructFallbackGraph(
+partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
-  auto convert_cfg = cfg.convert_info;
-  auto partition_info = cfg.partition_info;
-
-  auto new_g = std::make_shared<torch::jit::Graph>();
-
-  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);
-
-  // the mapping from lowering graph => fallback global graph
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-  for (auto input : block->inputs()) {
-    util::getOrAddInputForValue(input, new_g, old_to_new_g);
-  }
-
-  for (auto& seg_block : segmented_blocks) {
-    LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
-    std::ostringstream trt_engine_id;
-    trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-
-    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-      auto shapes = seg_block.in_shapes();
-      auto types = seg_block.in_types();
-      std::vector<ir::Input> inputs;
-      for (size_t i = 0; i < shapes.size(); i++) {
-        auto in = ir::Input(shapes[i]);
-        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-        inputs.push_back(in);
-      }
-      // update the input ranges for each segments
-      convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
-
-      // TODO mapping Inputs Ivalue to flatten one here
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
-      auto temp_g = std::make_shared<torch::jit::Graph>();
-      auto device_spec = convert_cfg.engine_settings.device;
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-
-      seg_block.update_graph(temp_g);
-      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-    } else {
-      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-        auto if_node = seg_block.raw_nodes()[0];
-
-        // convert the 2 blocks in prim::if and get the converted graph with mappings
-        std::vector<GraphAndMapping> graph_and_mappings;
-        for (auto cur_block : if_node->blocks()) {
-          graph_and_mappings.push_back(
-              ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
+    ir::CollectionTypeMap first_use_types) {
+  auto convert_info = cfg.convert_info;
+  auto partitioning_info = cfg.partitioning_info;
+
+  auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
+  auto collection_input_ivalues_map =
+      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
+
+  partitioning::partition(&partitioning_ctx, collection_input_ivalues_map);
+
+  for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
+    partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+
+    for (auto& seg_block : segmented_blocks) {
+      LOG_INFO("Block segment:" << seg_block);
+      std::ostringstream trt_engine_id;
+      trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+
+      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        auto shapes = seg_block.in_shapes();
+        auto types = seg_block.in_types();
+        std::vector<ir::Input> inputs;
+        for (size_t i = 0; i < shapes.size(); i++) {
+          auto in = ir::Input(shapes[i]);
+          in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+          inputs.push_back(in);
         }
-        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+        // update the input ranges for each segments
+        convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
-      } else {
-        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-      }
-    }
-  }
+        // TODO mapping Inputs Ivalue to flatten one here
+        auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
+        auto temp_g = std::make_shared<torch::jit::Graph>();
+        auto device_spec = convert_info.engine_settings.device;
+        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
 
-  if (block->outputs().size() > 1) {
-    std::vector<torch::jit::Value*> fallback_graph_vector;
-    for (auto& output : block->outputs()) {
-      if (old_to_new_g.count(output)) {
-        fallback_graph_vector.push_back(old_to_new_g[output]);
+        seg_block.update_graph(temp_g);
       }
     }
-    torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
-    auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
-    new_g->block()->appendNode(return_tuple_node);
-    // Set the output as the produced tuple
-    new_g->registerOutput(return_tuple_node->outputs()[0]);
-  } else {
-    if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
-      new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
-    }
   }
-  return {new_g, old_to_new_g};
+
+  return partitioning::stitch(&partitioning_ctx, block);
 }
 
 void MapInputsAndDetermineDTypes(
@@ -310,6 +184,8 @@ void MapInputsAndDetermineDTypes(
     ir::CollectionTypeMap& first_use_type_map) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+  cfg.partitioning_info.collection_input_spec_map =
+      ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
 
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
@@ -339,7 +215,7 @@ void MapInputsAndDetermineDTypes(
           "Cannot infer input type from calcuations in graph for input "
           << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
       spec[i].dtype = nvinfer1::DataType::kFLOAT;
-    } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+    } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
       if (!est_type_opt[i]) {
         LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
         std::stringstream ss;
@@ -424,22 +300,18 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partition_info.enabled &&
+  if (cfg.partitioning_info.enabled &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-       cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
+       cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
       !outputIsCollection) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
-  if (cfg.partition_info.enabled &&
+  if (cfg.partitioning_info.enabled &&
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+         cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
        outputIsCollection)) {
-    std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-    auto collection_input_ivalues_map =
-        partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
-    auto graph_and_mapping = ConstructFallbackGraph(
-        new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
+    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
    for (size_t i = 0; i < new_g->inputs().size(); ++i) {
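
For context on the refactor above: the old flow threaded loose state (example_tensor_map, fallback_nodes, old_to_new_g) through ConstructFallbackGraph's recursion, while BuildHybridGraph now parks all of that state on a single PartitioningCtx that partitioning::partition writes into and partitioning::stitch reads back out. Below is a self-contained miniature of that context-object pattern; the Ctx/partition/stitch names mirror the diff, but the types are illustrative stand-ins, not torch_tensorrt's actual classes:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for the per-compilation state that a partitioning context centralizes.
struct Ctx {
  std::vector<std::string> blocks; // stand-in for torch::jit::Block*
  std::map<std::string, std::vector<std::string>> partitioned_blocks;
};

// Records segments on the context instead of returning loose collections.
void partition(Ctx* ctx) {
  for (const auto& b : ctx->blocks) {
    ctx->partitioned_blocks[b] = {b + ":trt", b + ":torch"};
  }
}

// Consumes the centralized state to emit the final stitched artifact.
std::string stitch(const Ctx* ctx) {
  std::string out;
  for (const auto& kv : ctx->partitioned_blocks) {
    for (const auto& seg : kv.second) {
      out += seg + ";";
    }
  }
  return out;
}

int main() {
  Ctx ctx{{"block0", "block1"}, {}};
  partition(&ctx);
  std::cout << stitch(&ctx) << std::endl; // block0:trt;block0:torch;block1:trt;block1:torch;
}

The payoff, visible in the diff, is that recursive helpers no longer need ever-growing parameter lists; any pass that needs partitioning state takes the one context pointer.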

core/compiler.h

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ struct CompileSpec {
   ir::GraphInputs graph_inputs;
   conversion::ConversionInfo convert_info;
   lowering::LowerInfo lower_info;
-  partitioning::PartitionInfo partition_info;
+  partitioning::PartitioningInfo partitioning_info;
 };
 
 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
     passes::MarkNodesForFallback(g, true);
   }
   passes::UnpackHardSwish(g);
+  passes::UnpackHardSigmoid(g);
   passes::EliminateExceptionOrPassPattern(g);
   passes::ReduceToOperation(g);
   passes::ReduceGelu(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ cc_library(
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
+        "unpack_hardsigmoid.cpp",
         "unpack_hardswish.cpp",
         "unpack_log_softmax.cpp",
         "unpack_std.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ target_sources(${lib_name}
     "${CMAKE_CURRENT_SOURCE_DIR}/silu_to_sigmoid_multiplication.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_addmm.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/unpack_batch_norm.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
 
 } // namespace passes
 } // namespace lowering
core/lowering/passes/unpack_hardsigmoid.cpp

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string hardsigmoid_pattern = R"IR(
+    graph(%input):
+      %result = aten::hardsigmoid(%input)
+      return (%result))IR";
+
+  std::string hardsigmoid_pattern_inplace = R"IR(
+    graph(%input):
+      %result = aten::hardsigmoid_(%input)
+      return (%result))IR";
+
+  std::string new_pattern = R"IR(
+    graph(%x.1):
+      %22 : float = prim::Constant[value=0.5]()
+      %3 : int = prim::Constant[value=6]()
+      %5 : int = prim::Constant[value=1]()
+      %10 : int = prim::Constant[value=0]()
+      %4 : Tensor = aten::div(%x.1, %3)
+      %9 : Tensor = aten::add(%4, %22, %5)
+      %21 : Tensor = aten::clamp(%9, %10, %5)
+      return (%21))IR";
+
+  torch::jit::SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(hardsigmoid_pattern, new_pattern);
+  rewriter.RegisterRewritePattern(hardsigmoid_pattern_inplace, new_pattern);
+  rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("Post unpack hardsigmoid: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
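
For reference, this new pass rewrites aten::hardsigmoid (and its in-place variant) into the closed form hardsigmoid(x) = clamp(x / 6 + 1/2, 0, 1), so the resulting div/add/clamp nodes can be handled by existing converters. A minimal sketch checking that decomposition against ATen, assuming a working libtorch install (this check is not part of the commit):

#include <torch/torch.h>

#include <iostream>

int main() {
  // hardsigmoid(x) = clamp(x / 6 + 0.5, 0, 1); the IR pattern above emits the
  // same div -> add -> clamp sequence that this expression produces.
  auto x = torch::randn({16});
  auto reference = at::hardsigmoid(x);
  auto unpacked = torch::clamp(x / 6.0 + 0.5, 0.0, 1.0);
  std::cout << std::boolalpha << torch::allclose(reference, unpacked) << std::endl; // true
  return 0;
}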

core/partitioning/BUILD

Lines changed: 4 additions & 8 deletions
@@ -13,22 +13,21 @@ config_setting(
 cc_library(
     name = "partitioning",
     srcs = [
-        "PartitionInfo.cpp",
-        "SegmentedBlock.cpp",
         "partitioning.cpp",
         "shape_analysis.cpp",
+        "stitching.cpp",
     ],
     hdrs = [
-        "PartitionInfo.h",
-        "SegmentedBlock.h",
         "partitioning.h",
-        "shape_analysis.h",
     ],
     deps = [
         "//core/util:prelude",
         "//core/ir",
         "//core/conversion",
         "//core/lowering",
+        "//core/partitioning/partitioningctx",
+        "//core/partitioning/partitioninginfo",
+        "//core/partitioning/segmentedblock",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],
@@ -39,10 +38,7 @@ cc_library(
 pkg_tar(
     name = "include",
     srcs = [
-        "PartitionInfo.h",
-        "SegmentedBlock.h",
         "partitioning.h",
-        "shape_analysis.h",
     ],
     package_dir = "core/partitioning/",
 )
