1- from typing import List , Optional , Union
1+ from typing import List , Optional , Sequence , Union
22
3+ import numpy as np
4+ import torch
35from torch .fx .node import Target
4- from torch_tensorrt .dynamo .conversion .converter_utils import SourceIR
5- from torch_tensorrt . fx . converters . converter_utils import (
6+ from torch_tensorrt .dynamo .conversion .converter_utils import (
7+ SourceIR ,
68 get_positive_dim ,
7- set_layer_name ,
9+ get_trt_tensor ,
810)
11+ from torch_tensorrt .fx .converters .converter_utils import set_layer_name
912from torch_tensorrt .fx .types import TRTNetwork , TRTTensor
1013
1114
@@ -14,9 +17,11 @@ def reshape(
1417 target : Union [Target , str ],
1518 source_ir : Optional [SourceIR ],
1619 name : str ,
17- input : TRTTensor ,
20+ input : Sequence [ Union [ TRTTensor , torch . Tensor , np . ndarray ]] ,
1821 shape : List [int ],
1922) -> TRTTensor :
23+ if not isinstance (input , TRTTensor ):
24+ input = get_trt_tensor (network , input , f"{ name } _input" )
2025 layer = network .add_shuffle (input )
2126 layer .reshape_dims = tuple (shape )
2227 set_layer_name (layer , target , f"{ name } _reshape" , source_ir )
@@ -28,7 +33,7 @@ def flatten(
2833 target : Union [Target , str ],
2934 source_ir : Optional [SourceIR ],
3035 name : str ,
31- input : TRTTensor ,
36+ input : Sequence [ Union [ TRTTensor , torch . Tensor , np . ndarray ]] ,
3237 start_dim : int ,
3338 end_dim : int ,
3439) -> TRTTensor :
@@ -37,6 +42,9 @@ def flatten(
3742 start_dim = get_positive_dim (start_dim , dim_size )
3843 end_dim = get_positive_dim (end_dim , dim_size )
3944
45+ if not isinstance (input , TRTTensor ):
46+ input = get_trt_tensor (network , input , f"{ name } _input" )
47+
4048 num_elements = 1
4149 for i in range (start_dim , end_dim + 1 ):
4250 num_elements *= shape [i ]
0 commit comments