chore: dynamic shape support for pdist ops #3068

Merged · 2 commits · Aug 8, 2024
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3569,7 +3569,9 @@ def aten_ops_any(
    )


-@dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten._pdist_forward.default, supports_dynamic_shapes=True
+)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
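For context, supports_dynamic_shapes=True opts the converter into graphs whose input dimensions are symbolic; to my understanding, without the flag the dynamo partitioner leaves such ops to run in PyTorch rather than TensorRT. A minimal sketch of the same registration pattern (illustrative only — aten_ops_relu and the exact import paths are my assumptions, not part of this PR):

import torch
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)

# Hypothetical registration following the pattern above (not in this PR):
# the flag marks the converter as safe for inputs with dynamic dimensions.
@dynamo_tensorrt_converter(torch.ops.aten.relu.default, supports_dynamic_shapes=True)
def aten_ops_relu(ctx, target, args, kwargs, name):
    return impl.activation.relu(ctx, target, SourceIR.ATEN, name, args[0])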
208 changes: 199 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -10,15 +10,18 @@
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    cast_trt_tensor,
    create_constant,
    get_axes_for_reduce_op,
    get_positive_dim,
    get_trt_tensor,
    to_numpy,
)
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
    to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import get_dynamic_dims

@@ -417,20 +420,21 @@ def pdist(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    shape = input.shape
    # Extend input from shape [N, D] to [N, 1, D]
-    extend_input = impl.shuffle.reshape(
+    extend_input = impl.unsqueeze.unsqueeze(
[Author review note] Same functionality; changed so it also handles the case where all shapes are dynamic.

        ctx,
        target,
        source_ir,
-        f"{name}_reshape",
+        f"{name}_unsqueeze",
        input,
-        shape=shape[0:1] + (1,) + shape[1:],
+        1,
    )

    # Expand the input from [N, 1, D] to [N, N, D]
    x = impl.slice.expand(
        ctx,
        target,
        source_ir,
-        f"{name}_sub",
+        f"{name}_expand",
        extend_input,
        (shape[0], shape[0]) + shape[1:],
    )
@@ -482,8 +486,194 @@ def pdist(
        raise RuntimeError(
            f"p should be between [0, inf], currently p={p} is not supported!"
        )
-    indices = np.triu_indices(shape[0], k=1)
-    return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
+    if shape[0] == DYNAMIC_DIM:
+        dim = get_shape(ctx, target, source_ir, f"{name}_get_shape", input, 0)
+        shuffle_layer = ctx.net.add_shuffle(dim)
+        shuffle_layer.reshape_dims = trt.Dims()
+        set_layer_name(shuffle_layer, target, f"{name}_shuffle", source_ir)
+        dim_tensor = shuffle_layer.get_output(0)
+        indices_tensor = tri_upper_indices(
+            ctx, target, source_ir, f"{name}_triu_indices", dim_tensor
+        )
+        gather_layer = ctx.net.add_gather_v2(
+            norm, indices_tensor, mode=trt.GatherMode.ND
+        )
+        set_layer_name(gather_layer, target, f"{name}_gather_layer", source_ir)
+        gather_layer.axis = 2
+        return gather_layer.get_output(0)
+    else:
+        indices = np.triu_indices(shape[0], k=1)
+        return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
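As background for the converter structure (an illustration of mine, not part of the diff): pdist is equivalent to building the full N-by-N pairwise-distance matrix and then gathering its strict upper triangle, which is exactly the expand/norm/gather pipeline above.

import torch

x = torch.randn(4, 8)
d = torch.ops.aten._pdist_forward.default(x, 2.0)  # N * (N - 1) / 2 == 6 values

# Dense pairwise distances, then strict-upper-triangle selection.
full = torch.cdist(x, x, p=2.0)                    # (4, 4)
i, j = torch.triu_indices(4, 4, offset=1)
assert torch.allclose(d, full[i, j], atol=1e-6)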


def tri_upper_indices(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    size_tensor: TRTTensor,
) -> TRTTensor:
"""
Return the indices for the upper-triangle part of a square size of matrix in a N-by-2 Tensor,
where the diagonal offset = 1. One loop is used to calculate the indices like below.
x = 0, y = 0, y_start = 1
out_size = size * (size - 1) // 2
for _ in range(out_size):
y_out.append(y_start + y)
x_out.append(x)
y += 1
if (y_start + y) >= size:
x += 1
y_start += 1
y = 0
Args:
ctx (ConversionContext): A ConversionContext containing the TensorRT network.
target (Target): Target of calling node.
source_ir (Optional[SourceIR]): SourceIR of calling converter.
name (str): Name of the calling layer.
size_tensor (TRTTensor): number of rows in the 2-D square matrix. scalar tensor.

Example:
if size_tensor is 4, it will return [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
"""
    constant_0 = create_constant(ctx, 0, f"{name}_zero", np.int32, 0)
    constant_1 = create_constant(ctx, 1, f"{name}_one", np.int32, 0)
    constant_2 = create_constant(ctx, 2, f"{name}_two", np.int32, 0)

    size_minus_one = impl.elementwise.sub(
        ctx, target, source_ir, f"{name}_size_minus_one", size_tensor, constant_1
    )

    size_mult_prev = impl.elementwise.mul(
        ctx, target, source_ir, f"{name}_size_mult_prev", size_tensor, size_minus_one
    )

    num_loop = impl.elementwise.floor_divide(
        ctx, target, source_ir, f"{name}_num_loop", size_mult_prev, constant_2
    )

    loop = ctx.net.add_loop()
    loop.add_trip_limit(num_loop, trt.TripLimit.COUNT)

    x_recurrence = loop.add_recurrence(constant_0)
    set_layer_name(x_recurrence, target, f"{name}_x_recurrence", source_ir)
    x_tensor = x_recurrence.get_output(0)

    y_recurrence = loop.add_recurrence(constant_0)
    set_layer_name(y_recurrence, target, f"{name}_y_recurrence", source_ir)
    y_tensor = y_recurrence.get_output(0)

    y_start_recurrence = loop.add_recurrence(constant_1)
    set_layer_name(y_start_recurrence, target, f"{name}_y_start_recurrence", source_ir)
    y_start_tensor = y_start_recurrence.get_output(0)

    x_inc = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_x_inc",
        x_tensor,
        constant_1,
    )

    y_current = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_y_current",
        y_start_tensor,
        y_tensor,
    )

    y_inc = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_y_inc",
        y_tensor,
        constant_1,
    )

    next_y = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_next_y",
        y_start_tensor,
        y_inc,
    )

    y_start_inc = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_y_start_inc",
        y_start_tensor,
        constant_1,
    )
    cond = ge(ctx, target, source_ir, f"{name}_cond", next_y, size_tensor)
    x_output = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_x_output",
        x_inc,
        x_tensor,
        cond,
    )
    x_recurrence.set_input(1, x_output)

    y_start_current = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_y_start_current",
        y_start_inc,
        y_start_tensor,
        cond,
    )
    y_start_recurrence.set_input(1, y_start_current)

    y_val = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_y_val",
        constant_0,
        y_inc,
        cond,
    )
    y_recurrence.set_input(1, y_val)

    loop_output_x = loop.add_loop_output(x_tensor, trt.LoopOutput.CONCATENATE)
    loop_output_y = loop.add_loop_output(y_current, trt.LoopOutput.CONCATENATE)
    loop_output_x.set_input(1, num_loop)
    loop_output_y.set_input(1, num_loop)

    # Cat two N tensors into 2 x N. [0, 0, 0], [1, 2, 3] -> [[0, 0, 0], [1, 2, 3]]
    x_index = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_x_index", loop_output_x.get_output(0), (1, -1)
    )
    y_index = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_y_index", loop_output_y.get_output(0), (1, -1)
    )

    x_y_tensor = cat(
        ctx,
        target,
        source_ir,
        f"{name}_x_y_tensor",
        [x_index, y_index],
        0,
    )

    # Reshape 2 x N output to N x 2. [[0, 0, 0], [1, 2, 3]] -> [[0, 1], [0, 2], [0, 3]]
    indices_tensor = ctx.net.add_shuffle(x_y_tensor)
    set_layer_name(indices_tensor, target, f"{name}_indices_tensor", source_ir)
    indices_tensor.first_transpose = trt.Permutation([1, 0])
    indices_tensor.reshape_dims = (-1, 2)

    return indices_tensor.get_output(0)
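To make the recurrence easier to verify, here is a pure-Python mirror of the loop built above (a sketch of mine, not part of the PR), checked against np.triu_indices:

import numpy as np

def tri_upper_indices_reference(size: int) -> np.ndarray:
    # Mirrors the x / y / y_start recurrence that tri_upper_indices unrolls in TensorRT.
    x, y, y_start = 0, 0, 1
    x_out, y_out = [], []
    for _ in range(size * (size - 1) // 2):
        x_out.append(x)
        y_out.append(y_start + y)
        y += 1
        if y_start + y >= size:  # current row exhausted: advance to the next row
            x += 1
            y_start += 1
            y = 0
    return np.stack([x_out, y_out], axis=1)  # N x 2, matching the converter output

rows, cols = np.triu_indices(4, k=1)
assert (tri_upper_indices_reference(4) == np.stack([rows, cols], axis=1)).all()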


def cdist_forward(
58 changes: 58 additions & 0 deletions tests/py/dynamo/conversion/test_pdist_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -32,5 +33,62 @@ def forward(self, input):
        )


class TestDynamicShapePdistConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "dim0_dynamic_dim1_static_p_0",
                (1, 4),
                (2, 4),
                (4, 4),
                0,
            ),
            (
                "dim0_static_dim1_dynamic_p_1",
                (3, 1),
                (3, 2),
                (3, 4),
                1,
            ),
            (
                "dim0_dynamic_dim1_static_p_other",
                (1, 5),
                (2, 5),
                (6, 5),
                0.4,
            ),
            (
                "dim0_dynamic_dim1_dynamic_p_inf",
                (1, 1),
                (2, 2),
                (5, 4),
                float("inf"),
            ),
            (
                "dim0_dynamic_dim1_dynamic_p_other",
                (2, 1),
                (3, 2),
                (4, 7),
                1.7,
            ),
        ]
    )
    def test_pdist_float(self, _, min_shape, opt_shape, max_shape, p):
        class Pdist(nn.Module):
            def forward(self, input):
                return torch.ops.aten._pdist_forward.default(input, p)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                dtype=torch.float,
            ),
        ]

        self.run_test_with_dynamic_shape(Pdist(), input_specs)
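The same dynamic ranges can also be exercised end to end outside the test harness. A hedged user-level sketch (my illustration; the compile call is standard torch_tensorrt usage rather than part of this PR, and requires a CUDA device):

import torch
import torch_tensorrt

class Pdist(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten._pdist_forward.default(x, 2.0)

inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 4), opt_shape=(2, 4), max_shape=(4, 4), dtype=torch.float
    )
]
trt_mod = torch_tensorrt.compile(Pdist().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(torch.randn(3, 4).cuda()))  # any N within [1, 4] is accepted at runtime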


if __name__ == "__main__":
run_tests()