docsrc/contributors/partitioning.rst
148 additions & 0 deletions

@@ -91,3 +91,151 @@ To enable automatic fallback feature, you can set following attributes in Python
cfg.torch_executed_ops.push_back("aten::relu");
auto trt_mod = torchtrt::ts::compile(mod, cfg);
auto out = trt_mod.forward({in});

Dependency Aware Partitioning
=============================

During segmentation, Torch-TensorRT uses a dependency graph of the input TorchScript nodes to reduce the number of segments created. Consider this example from the test `Partitioning.SegmentModelWithDependencyAwareness` in `tests/core/partitioning/test_segmentation.cpp <https://github.com/pytorch/TensorRT/blob/master/tests/core/partitioning/test_segmentation.cpp>`_.
In this graph `aten::lgamma` is not supported by conversion and must be partitioned into a Torch fallback segment. If Torch-TensorRT uses a greedy segmentation strategy that traverses the nodes of the input graph in order and gathers ops with the same target (TensorRT or Torch) into a segment until it encounters an op with a different target, the resulting partition contains 7 segments, many with just a single op.

This partition is valid, but the segmentation is suboptimal. The arithmetic ops and the `aten::lgamma` ops are each split into their own segments as we alternate between Torch and TensorRT targets in the linear traversal of the graph.

.. code-block:: none

    %add : Tensor = aten::add(%x, %y, %20)
    %x_lgamma : Tensor = aten::lgamma(%x)
    %mul : Tensor = aten::mul(%x, %y)
    %y_lgamma : Tensor = aten::lgamma(%y)
    %div : Tensor = aten::div(%x, %y)
    %div_lgamma : Tensor = aten::lgamma(%div)

Each of the arithmetic ops in this portion of the graph depends only on constants and the inputs `%x` and `%y`. The `aten::lgamma` ops depend on the inputs `%x`, `%y` and the output of the `aten::div`. This means that we could rewrite this portion of the input graph as below without changing the behavior of the graph. This reordered series of ops could be cleanly partitioned into just 2 segments using the greedy segmentation approach described above.

.. code-block:: none

    %add : Tensor = aten::add(%x, %y, %20)
    %mul : Tensor = aten::mul(%x, %y)
    %div : Tensor = aten::div(%x, %y)
    %x_lgamma : Tensor = aten::lgamma(%x)
    %y_lgamma : Tensor = aten::lgamma(%y)
    %div_lgamma : Tensor = aten::lgamma(%div)

By adding awareness of the dependencies between ops to the basic greedy segmentation approach, we can achieve the same partition without rewriting the graph. We now maintain both a Torch-targeted and a TensorRT-targeted segment at the same time as we traverse the graph, and we only finalize a segment once we hit an op that depends on an op in that segment and has a different target. This allows the partitioner to create larger segments by effectively reordering nodes across segment boundaries, while guaranteeing that we never reorder a node relative to its dependencies and therefore do not modify the behavior of the graph.

In this example we will collect the arithmetic ops in a TensorRT segment and the `aten::lgamma` ops in a Torch segment. When we encounter the `%div_lgamma : Tensor = aten::lgamma(%div)` op, we can see that it depends on `%div : Tensor = aten::div(%x, %y)` in the current TensorRT segment. This triggers finalization of the TensorRT segment containing the `aten::div` op, guaranteeing that it will appear before the op that depends on it in the final partition. The Torch segment containing the `aten::lgamma` ops is finalized when we encounter the `prim::ListConstruct` op, which targets TensorRT and depends on the results of the `aten::lgamma` ops.
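
A rough sketch of this strategy is shown below. It is only an illustration of the idea, not the actual Torch-TensorRT partitioning code; `Node`, `Segment`, `Target`, `isOpSupported` and `dependsOnAny` are placeholder names standing in for the corresponding pieces of the real implementation.

.. code-block:: cpp

    #include <vector>

    struct Node;                                        // placeholder graph node type
    bool isOpSupported(const Node* n);                  // assumed: true if the op can be converted to TensorRT
    bool dependsOnAny(const Node* n,
                      const std::vector<Node*>& deps);  // assumed: true if n consumes an output of any node in deps

    enum class Target { kTensorRT, kTorch };

    struct Segment {
      Target target;
      bool do_not_merge;          // set for conditionals / loops (see the merge step below)
      std::vector<Node*> nodes;
    };

    std::vector<Segment> segmentGraph(const std::vector<Node*>& nodes) {
      std::vector<Segment> finalized;
      Segment trt{Target::kTensorRT, false, {}};
      Segment torch{Target::kTorch, false, {}};

      for (Node* n : nodes) {
        Target target = isOpSupported(n) ? Target::kTensorRT : Target::kTorch;
        Segment& mine = (target == Target::kTensorRT) ? trt : torch;
        Segment& other = (target == Target::kTensorRT) ? torch : trt;

        // If this op depends on an op collected in the other in-progress segment,
        // finalize that segment so its ops are guaranteed to run first.
        if (dependsOnAny(n, other.nodes)) {
          finalized.push_back(other);
          other.nodes.clear();
        }
        mine.nodes.push_back(n);
      }

      // Flush whatever is left; the ordering of this final flush is simplified here.
      if (!trt.nodes.empty()) finalized.push_back(trt);
      if (!torch.nodes.empty()) finalized.push_back(torch);
      return finalized;
    }

In the example above, `aten::lgamma(%div)` targets Torch but depends on `aten::div` in the in-progress TensorRT segment, so the TensorRT segment is finalized at that point.
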
In some cases this approach may create adjacent segments in the partition that have the same target. As a clean-up step, we can consolidate these adjacent segments to further reduce the number of segments in the final partition.

The merge segments step identifies a list of segments that are adjacent in the graph, have the same target, and are not marked as `do_not_merge`. The nodes from these segments will be combined into a single new segment that will replace the merged segments in the partition.
The `do_not_merge` marking prevents merging of segments created for conditional nodes and loops, which are handled as special cases during graph stitching and should not be merged with adjacent segments of the same type.
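
Continuing the illustrative sketch above (again, placeholder names rather than the actual implementation), the merge step could look roughly like this:

.. code-block:: cpp

    // Collapse runs of adjacent segments that share a target, skipping any
    // segment marked do_not_merge (conditionals and loops handled specially
    // during graph stitching). Reuses the placeholder Segment struct from the
    // segmentation sketch above.
    std::vector<Segment> mergeAdjacentSegments(const std::vector<Segment>& segments) {
      std::vector<Segment> merged;
      for (const Segment& seg : segments) {
        bool can_merge = !merged.empty() &&
                         merged.back().target == seg.target &&
                         !merged.back().do_not_merge &&
                         !seg.do_not_merge;
        if (can_merge) {
          // Append this segment's nodes to the previous segment instead of
          // starting a new one.
          merged.back().nodes.insert(merged.back().nodes.end(),
                                     seg.nodes.begin(), seg.nodes.end());
        } else {
          merged.push_back(seg);
        }
      }
      return merged;
    }

Because this pass only looks at immediately adjacent segments, it naturally fits as the clean-up step described above rather than as part of the traversal itself.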