File tree Expand file tree Collapse file tree 2 files changed +7
-6
lines changed
py/torch_tensorrt/dynamo/conversion/impl
tests/py/dynamo/conversion Expand file tree Collapse file tree 2 files changed +7
-6
lines changed Original file line number Diff line number Diff line change 17
17
to_numpy ,
18
18
)
19
19
from torch_tensorrt .fx .converters .converter_utils import set_layer_name
20
-
21
- import tensorrt as trt
20
+ from torch_tensorrt .fx .types import TRTTensor
22
21
23
22
24
23
def embedding (
@@ -31,6 +30,10 @@ def embedding(
31
30
) -> TRTTensor :
32
31
indices_tensor = input
33
32
embedding_tensor = weight
33
+ if isinstance (indices_tensor , torch .Tensor ) and indices_tensor .dtype == torch .int64 :
34
+ raise RuntimeError (
35
+ "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
36
+ )
34
37
indices_tensor = get_trt_tensor (ctx , indices_tensor , f"{ name } _indices_tensor" )
35
38
embedding_tensor = get_trt_tensor (ctx , embedding_tensor , f"{ name } _embedding_tensor" )
36
39
# unsupported parameters
Original file line number Diff line number Diff line change @@ -62,7 +62,6 @@ def run_test(
62
62
cuda_inputs .append (i .cuda ())
63
63
64
64
mod .eval ()
65
- mod = mod .cuda ()
66
65
start = time .perf_counter ()
67
66
interpreter_result = interpreter .run ()
68
67
sec = time .perf_counter () - start
@@ -73,6 +72,7 @@ def run_test(
73
72
interpreter_result .output_names ,
74
73
)
75
74
75
+ mod = mod .cuda ()
76
76
ref_outputs = mod (* cuda_inputs )
77
77
78
78
torch .cuda .synchronize ()
@@ -96,11 +96,9 @@ def run_test(
96
96
):
97
97
ref_outputs = [ref_outputs ]
98
98
for out , ref in zip (outputs , ref_outputs ):
99
- ref = ref .cpu () # to_dtype test has cases with gpu output
100
99
if not isinstance (ref , torch .Tensor ):
101
100
ref = torch .tensor ([ref ])
102
- if ref .dtype == torch .int64 :
103
- ref = ref .int () # convert torch.max's index output tensor to int32
101
+ ref = ref .cpu () # to_dtype test has cases with gpu output
104
102
torch .testing .assert_close (
105
103
out .cpu (),
106
104
ref ,
You can’t perform that action at this time.
0 commit comments