chore: bug fixes for full and expand #3019

Merged (5 commits) on Jul 31, 2024

Changes from 1 commit

8 changes: 7 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/full.py
@@ -23,7 +23,13 @@ def full(
 ) -> TRTTensor:
     # in static shape scenario, shape is a list of int
     if isinstance(shape, List):
-        return np.full(shape, fill_value)
+        # in static shape scenario, shape is a list of int
+        if all(isinstance(dim, int) for dim in shape):
+            return np.full(shape, fill_value)
+        else:
+            shape = impl.cat.cat(
+                ctx, target, source_ir, name + "_concat_shape", shape, 0
+            )
 
     # in dynamic shape scenario, shape is a shap tensor
     # use IFillLayer to fill the shape tensor with LINSPACE value
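Note: the crux of this fix is that `shape` can arrive as a Python list whose entries mix ints with ITensors once any dimension is dynamic, and `np.full` only handles the all-int case. A minimal, framework-free sketch of the new branching (the two callables stand in for the converter's cat and IFillLayer paths and are illustrative, not its API):

```python
import numpy as np

def full_sketch(shape, fill_value, concat_to_shape_tensor, fill_from_shape_tensor):
    """Illustrative dispatch mirroring the converter's branching above."""
    if isinstance(shape, list):
        if all(isinstance(dim, int) for dim in shape):
            # fully static shape: materialize the constant directly
            return np.full(shape, fill_value)
        # mixed int / runtime-tensor entries: fold them into one 1-D shape tensor
        shape = concat_to_shape_tensor(shape)
    # dynamic path: fill a tensor whose extent is only known at runtime
    return fill_from_shape_tensor(shape, fill_value)
```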
57 changes: 35 additions & 22 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -226,6 +226,7 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
+
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -244,39 +245,49 @@ def expand(
     # After the above padding, the shape and tensor rank must be equal
     assert len(input_t.shape) == shape_rank
 
-    shape_t = []
-    for i in range(shape_rank):
-        if shape[i] == -1:
-            shape_t.append(
-                get_shape(ctx, target, source_ir, name + f"_shape_dim{i}", input_t, i)
-            )
-        else:
-            shape_t.append(shape[i])
-
-    # Establish the desired output shape, strides, and starting indices
-    input_tensor_shape = tuple(input_t.shape)
+    # Configure the start, strides and output shape tensors
     start = tuple([0] * shape_rank)
 
-    # TODO: Revisit stride calculation. stride[dim]=0 implies that dimension is being broadcasted.
+    # stride[dim]=0 implies that dimension is being broadcasted.
     # stride should be 1 for all non-broadcasted dims
     stride = []
-    for i, o in zip(input_tensor_shape, shape_t):
-        # If the shape has ITensor, we treat it as a reshape dim instead of a broadcasted dim
-        # shape_t cannot have -1. If the input at this dimension has a shape of -1, set the stride to 1. This indicates that the input is dynamic and does not imply broadcasting at that specific dimension.
-        if isinstance(i, int) and isinstance(o, int) and i != DYNAMIC_DIM:
+    input_tensor_shape = tuple(input_t.shape)
+    for i, o in zip(input_tensor_shape, shape):
+        # If input dim and target shape dim are static, broadcast if they are not equal
+        # If a dimension of target shape has ITensor, we treat it as a broadcasted dim
+        if (
+            isinstance(i, int)
+            and i != DYNAMIC_DIM
+            and isinstance(o, int)
+            and o != DYNAMIC_DIM
+        ):
             stride.append(int(i == o))
+        elif isinstance(o, TRTTensor):
+            stride.append(0)
         else:
+            # No broadcasting is happening. The output should have the same size as input at this dimension.
            stride.append(1)
 
-    shape_ = shape_t
+    # Resolve dynamic dimensions in the target shape. These are not broadcasted dims.
+    # The value at this dimension should be same as input.
+    target_shape = []
+    for i in range(shape_rank):
+        if shape[i] == DYNAMIC_DIM:
+            target_shape.append(
+                get_shape(ctx, target, source_ir, name + f"_shape_dim{i}", input_t, i)
+            )
+        else:
+            target_shape.append(shape[i])
+
+    target_shape_t = target_shape
     # Handle dynamic shapes case where shape has dynamic dimension
-    if any(isinstance(ele, TRTTensor) for ele in shape_t):
-        shape_ = cat(
+    if any(isinstance(ele, TRTTensor) for ele in target_shape_t):
+        target_shape_t = cat(
             ctx,
             target,
             source_ir,
             name + "_shape_concat",
-            shape_t,
+            target_shape_t,
             0,
             cast_dtype=trt.int32,
         )
@@ -302,10 +313,12 @@ def expand(
             input_t, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
         )
         layer.set_input(1, start_tensor)
-        layer.set_input(2, shape_)
+        layer.set_input(2, target_shape_t)
         layer.set_input(3, stride_tensor)
     else:
-        layer = ctx.net.add_slice(input_t, start=start, shape=shape_, stride=stride)
+        layer = ctx.net.add_slice(
+            input_t, start=start, shape=target_shape_t, stride=stride
+        )
 
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
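Note: to make the stride rule concrete, here is a small framework-free sketch (assumptions: `DYNAMIC_DIM` is the -1 sentinel used by the converters, and any non-int entry stands in for an ITensor). A stride of 0 broadcasts a dimension; a stride of 1 copies the input's extent:

```python
DYNAMIC_DIM = -1  # assumed sentinel for a dynamic dimension

def expand_strides(input_shape, target_shape):
    """Stride 0 -> broadcast this dim, stride 1 -> keep the input's extent."""
    strides = []
    for i, o in zip(input_shape, target_shape):
        if (
            isinstance(i, int)
            and i != DYNAMIC_DIM
            and isinstance(o, int)
            and o != DYNAMIC_DIM
        ):
            # static vs static: broadcast only if the extents differ (e.g. 1 -> N)
            strides.append(int(i == o))
        elif not isinstance(o, int):
            # target dim is a runtime tensor: treat it as broadcasted
            strides.append(0)
        else:
            # dynamic input dim or -1 target dim: copy, no broadcast implied
            strides.append(1)
    return strides

# cls_token of shape (1, 1, 768) expanded to (batch, -1, -1), batch being a runtime tensor:
print(expand_strides((1, 1, 768), ("batch_itensor", -1, -1)))  # [0, 1, 1]
```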
9 changes: 8 additions & 1 deletion tests/py/dynamo/conversion/harness.py
@@ -399,7 +399,14 @@ def run_test_with_dynamic_shape(
         )
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. max shape) for testing dynamic shape
-        inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
+        inputs_max = [
+            (
+                spec.example_tensor("max_shape")
+                if spec.shape_mode == Input._ShapeMode.DYNAMIC
+                else spec.example_tensor()
+            )
+            for spec in input_specs
+        ]
         if not use_example_tensors:
             inputs_max = [spec.torch_tensor for spec in input_specs]
         super().run_test(mod, inputs_max, interp, rtol, atol, pyt_inputs=pyt_inputs)
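Note: the guard matters because only dynamic Input specs carry a min/opt/max optimization profile; a static spec has no "max_shape" to sample, so the harness falls back to the plain example tensor (the assumption here being that example_tensor("max_shape") is only meaningful for dynamic specs). A short illustration with made-up shapes:

```python
import torch
from torch_tensorrt import Input

specs = [
    # dynamic spec: carries min/opt/max optimization-profile shapes
    Input(min_shape=(1, 5, 3), opt_shape=(3, 7, 3), max_shape=(4, 10, 4), dtype=torch.float32),
    # static spec: one fixed shape, no profile to pick "max_shape" from
    Input(shape=(4, 1, 768), dtype=torch.float32),
]

tensors = [
    (
        spec.example_tensor("max_shape")
        if spec.shape_mode == Input._ShapeMode.DYNAMIC
        else spec.example_tensor()
    )
    for spec in specs
]
print([tuple(t.shape) for t in tensors])  # [(4, 10, 4), (4, 1, 768)]
```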
32 changes: 29 additions & 3 deletions tests/py/dynamo/conversion/test_expand_aten.py
@@ -37,8 +37,10 @@ def forward(self, x):
             ("different_ranks", (1, 2, 1), (1, 2, 1), (2, 2, 1), (2, -1, -1, -1)),
         ]
     )
-    def test_expand_dynamic(self, _, min_shape, opt_shape, max_shape, expanded_shape):
-        class ExpandDynamic(nn.Module):
+    def test_expand_dynamic_input(
+        self, _, min_shape, opt_shape, max_shape, expanded_shape
+    ):
+        class ExpandInputDynamic(nn.Module):
             def forward(self, x):
                 return torch.ops.aten.expand.default(x, expanded_shape)
 
@@ -51,10 +53,34 @@ def forward(self, x):
             ),
         ]
         self.run_test_with_dynamic_shape(
-            ExpandDynamic(),
+            ExpandInputDynamic(),
             input_specs,
         )
 
+    @parameterized.expand(
+        [
+            ("3d_dim", (4, 1, 768), (1, 1, 768)),
+        ]
+    )
+    def test_expand_dynamic_target_shape(self, _, input_shape, weight_shape):
+        class ExpandTargetDynamic(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.cls_token = torch.nn.Parameter(torch.randn(weight_shape).cuda())
+
+            def forward(self, x):
+                batch_size = x.shape[0]
+                cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+                embeddings = torch.cat((cls_tokens, x), dim=0)
+                return embeddings
+
+        input_specs = [
+            Input(dtype=torch.float32, shape=input_shape),
+        ]
+        self.run_test_with_dynamic_shape(
+            ExpandTargetDynamic(), input_specs, use_dynamo_tracer=True
+        )
+
 
 if __name__ == "__main__":
     run_tests()
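Note: for readers who want to reproduce the scenario this test encodes outside the harness, a hypothetical end-to-end snippet (model, shapes, and settings are illustrative, not taken from the PR) that compiles a ViT-style cls-token expand with a dynamic batch dimension through the dynamo frontend:

```python
import torch
import torch_tensorrt

class ClsTokenPrepend(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, 768))

    def forward(self, x):
        # the expand target shape depends on the (dynamic) batch size
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat((cls, x), dim=1)

model = ClsTokenPrepend().eval().cuda()
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 196, 768),
        opt_shape=(4, 196, 768),
        max_shape=(8, 196, 768),
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
out = trt_model(torch.randn(2, 196, 768).cuda())  # any batch size in [1, 8]
```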
26 changes: 26 additions & 0 deletions tests/py/dynamo/conversion/test_full_aten.py
@@ -55,6 +55,32 @@ def forward(self, shape):
             use_example_tensors=False,
         )
 
+    @parameterized.expand(
+        [
+            ((1, 5, 3), (3, 7, 3), (4, 10, 4), 0.11),
+        ]
+    )
+    def test_full_dynamic_shape_list(self, min_shape, opt_shape, max_shape, fill_value):
+        class full(nn.Module):
+            def forward(self, x):
+                shape = x.shape[0]
+                target_shape = (shape, shape + 1)
+                return torch.ops.aten.full.default(target_shape, fill_value)
+
+        inputs = [
+            torch_tensorrt.Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.int64,
+            )
+        ]
+        self.run_test_with_dynamic_shape(
+            full(),
+            inputs,
+            use_dynamo_tracer=True,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
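Note: for context, a dynamic aten.full like the one exercised above typically comes from tracing a module whose fill shape is derived from an input's symbolic size. A hypothetical sketch using torch.export (module and dimension bounds are illustrative):

```python
import torch
from torch.export import Dim, export

class FillLike(torch.nn.Module):
    def forward(self, x):
        n = x.shape[0]  # symbolic under dynamic-shape export
        # lowers to aten.full.default with a shape list containing symbolic sizes
        return torch.full((n, n + 1), 0.11)

ep = export(
    FillLike(),
    (torch.zeros(3, 7, 3, dtype=torch.int64),),
    dynamic_shapes={"x": {0: Dim("batch", min=2, max=8)}},
)
print(ep.graph)
```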