84 changes: 2 additions & 82 deletions tests/py/dynamo/conversion/harness.py
@@ -4,28 +4,13 @@
from typing import Callable, List, Optional, Set, Tuple

import torch
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
from torch.fx.passes.infra.pass_base import PassResult
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt import Input
from torch_tensorrt.dynamo._settings import CompilationSettings

# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
compose_bmm,
compose_chunk,
compose_getitem_slice,
remove_ops,
replace_aten_op_with_indices,
replace_aten_reshape_alias_with_replace,
replace_builtin_ops,
replace_native_layernorm_with_layernorm,
replace_transpose_mm_op_with_linear,
run_const_fold,
)
from torch_tensorrt.fx.passes.pass_utils import chain_passes

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -61,8 +46,6 @@ def run_test(
self,
mod,
inputs,
expected_ops,
unexpected_ops,
interpreter,
rtol,
atol,
@@ -75,10 +58,6 @@ def run_test(
cuda_inputs.append(i.cuda())

mod.eval()
if len(expected_ops):
self.assert_has_op(mod, expected_ops)
if unexpected_ops:
self.assert_unexpected_op(mod, unexpected_ops)
start = time.perf_counter()
interpreter_result = interpreter.run(precision=precision)
sec = time.perf_counter() - start
@@ -214,75 +193,27 @@ def generate_graph(
self,
mod: torch.nn.Module,
original_inputs: List[torch.Tensor],
expected_ops: Set[Callable],
unexpected_ops: Optional[Set[Callable]] = None,
customized_passes: List[Callable] = None,
disable_passes: bool = False,
):
# Torchdynamo+aot proxytensor tracer
# Below are common passes
passes_list = [
compose_bmm,
compose_chunk,
compose_getitem_slice,
replace_aten_reshape_alias_with_replace,
replace_aten_op_with_indices,
replace_transpose_mm_op_with_linear, # after compose_bmm
replace_native_layernorm_with_layernorm,
remove_ops,
replace_builtin_ops, # after replace_native_layernorm_with_layernorm
]
# Combine with customized passes specific to any model
if customized_passes:
passes_list.extend(customized_passes)

if disable_passes:
passes_list = []

fx_module, _ = aten_tracer.trace(mod, original_inputs)
for passes in passes_list:
pr: PassResult = passes(fx_module)
fx_module = pr.graph_module
fx_module(*original_inputs)

fx_module = run_const_fold(fx_module)
fx_module = torch.fx.symbolic_trace(mod)
_LOGGER.info(f"FX graph= {fx_module.graph}")

if len(expected_ops):
self.assert_has_op(fx_module, expected_ops)
if unexpected_ops:
self.assert_unexpected_op(fx_module, unexpected_ops)

return fx_module

def run_test(
self,
mod,
inputs,
expected_ops,
unexpected_ops=None,
apply_passes=None,
rtol=1e-03,
atol=1e-03,
precision=torch.float,
check_dtype=True,
disable_passes=False,
output_dtypes=None,
):
mod.eval()
mod = self.generate_graph(
mod,
inputs,
expected_ops,
unexpected_ops,
None,
disable_passes=disable_passes,
)

if apply_passes is not None:
pass_tracer = chain_passes(*apply_passes)
mod = pass_tracer(mod, inputs)

# Previous instance of the interpreter auto-casted 64-bit inputs
# We replicate this behavior here
compilation_settings = CompilationSettings(truncate_long_and_double=True)
@@ -296,8 +227,6 @@ def run_test(
super().run_test(
mod,
inputs,
expected_ops,
unexpected_ops,
interp,
rtol,
atol,
@@ -309,22 +238,15 @@ def run_test_with_dynamic_shape(
self,
mod,
input_specs,
expected_ops,
unexpected_ops=None,
rtol=1e-03,
atol=1e-03,
disable_passes=False,
output_dtypes=None,
):
mod.eval()
inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
mod = self.generate_graph(
mod,
inputs,
expected_ops,
unexpected_ops,
None,
disable_passes=disable_passes,
)

# Previous instance of the interpreter auto-casted 64-bit inputs
@@ -340,6 +262,4 @@ def run_test_with_dynamic_shape(
# Since the lowering is based on the optimal shape, we need to test with a
# different shape (e.g. the max shape) to exercise dynamic shapes
inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
super().run_test(
mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol
)
super().run_test(mod, inputs_max, interp, rtol, atol)
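With the op-set assertions removed from the harness, a converter test only needs a small nn.Module whose forward calls the ATen overload directly, and then hands it to run_test. Below is a minimal sketch of the new calling pattern (the relu op, class names, and import path are illustrative, not part of this PR); run_test_with_dynamic_shape is driven the same way, with Input specs in place of example tensors.

import torch
import torch.nn as nn

from harness import DispatchTestCase  # import path assumed from the test layout


class TestReluConverter(DispatchTestCase):
    def test_relu_float(self):
        class Relu(nn.Module):
            def forward(self, x):
                # Call the ATen overload directly rather than torch.relu
                return torch.ops.aten.relu.default(x)

        inputs = [torch.randn(2, 3)]
        # expected_ops/unexpected_ops are gone; the harness now only
        # compares the TRT result against the eager module numerically.
        self.run_test(Relu(), inputs)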
6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/test_abs_aten.py
@@ -18,13 +18,12 @@ class TestAbsConverter(DispatchTestCase):
def test_abs_float(self, input_shape, dtype):
class abs(nn.Module):
def forward(self, input):
return torch.abs(input)
return torch.ops.aten.abs.default(input)

inputs = [torch.randn(input_shape, dtype=dtype)]
self.run_test(
abs(),
inputs,
expected_ops={torch.ops.aten.abs.default},
)

@parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
def test_abs_int(self, input_shape, dtype, low, high):
class abs(nn.Module):
def forward(self, input):
return torch.abs(input)
return torch.ops.aten.abs.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
abs(),
inputs,
expected_ops={torch.ops.aten.abs.default},
output_dtypes=[torch.int],
)

6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/test_acos_aten.py
@@ -18,13 +18,12 @@ class TestAcosConverter(DispatchTestCase):
def test_acos_float(self, input_shape, dtype):
class acos(nn.Module):
def forward(self, input):
return torch.acos(input)
return torch.ops.aten.acos.default(input)

inputs = [torch.randn(input_shape, dtype=dtype)]
self.run_test(
acos(),
inputs,
expected_ops={torch.ops.aten.acos.default},
)

@parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
def test_acos_int(self, input_shape, dtype, low, high):
class acos(nn.Module):
def forward(self, input):
return torch.acos(input)
return torch.ops.aten.acos.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
acos(),
inputs,
expected_ops={torch.ops.aten.acos.default},
)


6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/test_acosh_aten.py
@@ -18,13 +18,12 @@ class TestAcoshConverter(DispatchTestCase):
def test_acosh_float(self, input_shape, dtype):
class acosh(nn.Module):
def forward(self, input):
return torch.acosh(input)
return torch.ops.aten.acosh.default(input)

inputs = [torch.randn(input_shape, dtype=dtype)]
self.run_test(
acosh(),
inputs,
expected_ops={torch.ops.aten.acosh.default},
)

@parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
def test_acosh_int(self, input_shape, dtype, low, high):
class acosh(nn.Module):
def forward(self, input):
return torch.acosh(input)
return torch.ops.aten.acosh.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
acosh(),
inputs,
expected_ops={torch.ops.aten.acosh.default},
)


12 changes: 4 additions & 8 deletions tests/py/dynamo/conversion/test_add_aten.py
@@ -17,13 +17,12 @@ class TestAddConverter(DispatchTestCase):
def test_add_tensor(self, _, shape):
class add(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.add(lhs_val, rhs_val)
return torch.ops.aten.add.Tensor(lhs_val, rhs_val)

inputs = [torch.randn(shape), torch.randn(shape)]
self.run_test(
add(),
inputs,
expected_ops={torch.ops.aten.add.Tensor},
)

@parameterized.expand(
@@ -35,13 +34,12 @@ def forward(self, lhs_val, rhs_val):
def test_add_tensor_alpha(self, _, shape, alpha):
class add(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.add(lhs_val, rhs_val, alpha=alpha)
return torch.ops.aten.add.Tensor(lhs_val, rhs_val, alpha=alpha)

inputs = [torch.randn(shape), torch.randn(shape)]
self.run_test(
add(),
inputs,
expected_ops={torch.ops.aten.add.Tensor},
)

@parameterized.expand(
@@ -53,13 +51,12 @@ def forward(self, lhs_val, rhs_val):
def test_add_scalar(self, _, shape, scalar):
class add(nn.Module):
def forward(self, lhs_val):
return torch.add(lhs_val, scalar)
return torch.ops.aten.add.Tensor(lhs_val, scalar)

inputs = [torch.randn(shape)]
self.run_test(
add(),
inputs,
expected_ops={torch.ops.aten.add.Tensor},
)

@parameterized.expand(
@@ -71,13 +68,12 @@ def forward(self, lhs_val):
def test_add_scalar_alpha(self, _, shape, scalar, alpha):
class add(nn.Module):
def forward(self, lhs_val):
return torch.add(lhs_val, scalar, alpha=alpha)
return torch.ops.aten.add.Tensor(lhs_val, scalar, alpha=alpha)

inputs = [torch.randn(shape)]
self.run_test(
add(),
inputs,
expected_ops={torch.ops.aten.add.Tensor},
)


6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/test_asin_aten.py
@@ -18,13 +18,12 @@ class TestAsinConverter(DispatchTestCase):
def test_asin_float(self, input_shape, dtype):
class asin(nn.Module):
def forward(self, input):
return torch.asin(input)
return torch.ops.aten.asin.default(input)

inputs = [torch.randn(input_shape, dtype=dtype)]
self.run_test(
asin(),
inputs,
expected_ops={torch.ops.aten.asin.default},
)

@parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
def test_asin_int(self, input_shape, dtype, low, high):
class asin(nn.Module):
def forward(self, input):
return torch.asin(input)
return torch.ops.aten.asin.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
asin(),
inputs,
expected_ops={torch.ops.aten.asin.default},
)


6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/test_asinh_aten.py
@@ -18,13 +18,12 @@ class TestAsinhConverter(DispatchTestCase):
def test_asinh_float(self, input_shape, dtype):
class asinh(nn.Module):
def forward(self, input):
return torch.asinh(input)
return torch.ops.aten.asinh.default(input)

inputs = [torch.randn(input_shape, dtype=dtype)]
self.run_test(
asinh(),
inputs,
expected_ops={torch.ops.aten.asinh.default},
)

@parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
def test_asinh_int(self, input_shape, dtype, low, high):
class asinh(nn.Module):
def forward(self, input):
return torch.asinh(input)
return torch.ops.aten.asinh.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
asinh(),
inputs,
expected_ops={torch.ops.aten.asinh.default},
)

