Skip to content

Commit 1f58d47

Browse files
committed
rebase
1 parent a95a217 commit 1f58d47

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
to_numpy,
1818
)
1919
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
20-
21-
import tensorrt as trt
20+
from torch_tensorrt.fx.types import TRTTensor
2221

2322

2423
def embedding(
@@ -31,6 +30,10 @@ def embedding(
3130
) -> TRTTensor:
3231
indices_tensor = input
3332
embedding_tensor = weight
33+
if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
34+
raise RuntimeError(
35+
"The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
36+
)
3437
indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
3538
embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")
3639
# unsupported parameters

tests/py/dynamo/conversion/harness.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def run_test(
6262
cuda_inputs.append(i.cuda())
6363

6464
mod.eval()
65-
mod = mod.cuda()
6665
start = time.perf_counter()
6766
interpreter_result = interpreter.run()
6867
sec = time.perf_counter() - start
@@ -73,6 +72,7 @@ def run_test(
7372
interpreter_result.output_names,
7473
)
7574

75+
mod = mod.cuda()
7676
ref_outputs = mod(*cuda_inputs)
7777

7878
torch.cuda.synchronize()
@@ -96,11 +96,9 @@ def run_test(
9696
):
9797
ref_outputs = [ref_outputs]
9898
for out, ref in zip(outputs, ref_outputs):
99-
ref = ref.cpu() # to_dtype test has cases with gpu output
10099
if not isinstance(ref, torch.Tensor):
101100
ref = torch.tensor([ref])
102-
if ref.dtype == torch.int64:
103-
ref = ref.int() # convert torch.max's index output tensor to int32
101+
ref = ref.cpu() # to_dtype test has cases with gpu output
104102
torch.testing.assert_close(
105103
out.cpu(),
106104
ref,

0 commit comments

Comments
 (0)