Skip to content

Commit d212e51

Browse files
committed
update
1 parent 12840e2 commit d212e51

File tree

1 file changed

+14
-6
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+14
-6
lines changed

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
from typing import List, Optional, Union
1+
from typing import List, Optional, Sequence, Union
22

3+
import numpy as np
4+
import torch
35
from torch.fx.node import Target
4-
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
5-
from torch_tensorrt.fx.converters.converter_utils import (
6+
from torch_tensorrt.dynamo.conversion.converter_utils import (
7+
SourceIR,
68
get_positive_dim,
7-
set_layer_name,
9+
get_trt_tensor,
810
)
11+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
912
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1013

1114

@@ -14,9 +17,11 @@ def reshape(
1417
target: Union[Target, str],
1518
source_ir: Optional[SourceIR],
1619
name: str,
17-
input: TRTTensor,
20+
input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
1821
shape: List[int],
1922
) -> TRTTensor:
23+
if not isinstance(input, TRTTensor):
24+
input = get_trt_tensor(network, input, f"{name}_input")
2025
layer = network.add_shuffle(input)
2126
layer.reshape_dims = tuple(shape)
2227
set_layer_name(layer, target, f"{name}_reshape", source_ir)
@@ -28,7 +33,7 @@ def flatten(
2833
target: Union[Target, str],
2934
source_ir: Optional[SourceIR],
3035
name: str,
31-
input: TRTTensor,
36+
input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
3237
start_dim: int,
3338
end_dim: int,
3439
) -> TRTTensor:
@@ -37,6 +42,9 @@ def flatten(
3742
start_dim = get_positive_dim(start_dim, dim_size)
3843
end_dim = get_positive_dim(end_dim, dim_size)
3944

45+
if not isinstance(input, TRTTensor):
46+
input = get_trt_tensor(network, input, f"{name}_input")
47+
4048
num_elements = 1
4149
for i in range(start_dim, end_dim + 1):
4250
num_elements *= shape[i]

0 commit comments

Comments
 (0)