Skip to content

Commit 17a9903

Browse files
authored
Merge pull request #1438 from pytorch/noxfile_update
chore: Nox file update from NGC 22.11 release
2 parents 7b37ada + a91ba8a commit 17a9903

File tree

1 file changed

+149
-3
lines changed

1 file changed

+149
-3
lines changed

noxfile.py

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# Use system installed Python packages
77
PYT_PATH = (
8-
"/opt/conda/lib/python3.8/site-packages"
8+
"/usr/local/lib/python3.8/dist-packages"
99
if not "PYT_PATH" in os.environ
1010
else os.environ["PYT_PATH"]
1111
)
@@ -202,6 +202,93 @@ def run_base_tests(session):
202202
else:
203203
session.run_always("pytest", test)
204204

205+
def run_fx_core_tests(session):
206+
print("Running FX core tests")
207+
session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test"))
208+
tests = [
209+
"core",
210+
]
211+
for test in tests:
212+
if USE_HOST_DEPS:
213+
session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
214+
else:
215+
session.run_always("pytest", test)
216+
217+
def run_fx_converter_tests(session):
218+
print("Running FX converter tests")
219+
session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test"))
220+
tests = [
221+
"converters",
222+
]
223+
# Skipping this test as it fails inside NGC container with the following error.
224+
# Error Code 4: Internal Error (Could not find any implementation for node conv due to insufficient workspace. See verbose log for requested sizes.)
225+
skip_tests = "-k not conv3d"
226+
for test in tests:
227+
if USE_HOST_DEPS:
228+
session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH})
229+
else:
230+
session.run_always("pytest", test, skip_tests)
231+
232+
def run_fx_lower_tests(session):
233+
print("Running FX passes and trt_lower tests")
234+
session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test"))
235+
tests = [
236+
"passes/test_multi_fuse_trt.py",
237+
# "passes/test_fuse_permute_linear_trt.py",
238+
"passes/test_remove_duplicate_output_args.py",
239+
"passes/test_fuse_permute_matmul_trt.py",
240+
#"passes/test_graph_opts.py"
241+
"trt_lower",
242+
]
243+
for test in tests:
244+
if USE_HOST_DEPS:
245+
session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
246+
else:
247+
session.run_always("pytest", test)
248+
249+
def run_fx_quant_tests(session):
250+
print("Running FX Quant tests")
251+
session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test"))
252+
tests = [
253+
"quant",
254+
]
255+
# Skipping this test as it fails inside NGC container with the following error.
256+
# ImportError: cannot import name 'ObservationType' from 'torch.ao.quantization.backend_config.observation_type'
257+
skip_tests = "-k not conv_add_standalone_module"
258+
for test in tests:
259+
if USE_HOST_DEPS:
260+
session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH})
261+
else:
262+
session.run_always("pytest", test, skip_tests)
263+
264+
def run_fx_tracer_tests(session):
265+
print("Running FX Tracer tests")
266+
session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test"))
267+
# skipping a test since it depends on torchdynamo
268+
# Enable this test once NGC moves to latest pytorch which has dynamo integrated.
269+
tests = [
270+
"tracer/test_acc_shape_prop.py",
271+
"tracer/test_acc_tracer.py",
272+
#"tracer/test_dispatch_tracer.py"
273+
]
274+
for test in tests:
275+
if USE_HOST_DEPS:
276+
session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
277+
else:
278+
session.run_always("pytest", test)
279+
280+
def run_fx_tools_tests(session):
281+
print("Running FX tools tests")
282+
session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test"))
283+
tests = [
284+
"tools",
285+
]
286+
for test in tests:
287+
if USE_HOST_DEPS:
288+
session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
289+
else:
290+
session.run_always("pytest", test)
291+
205292

206293
def run_model_tests(session):
207294
print("Running model tests")
@@ -309,6 +396,35 @@ def run_l0_api_tests(session):
309396
run_base_tests(session)
310397
cleanup(session)
311398

399+
def run_l0_fx_tests(session):
400+
if not USE_HOST_DEPS:
401+
install_deps(session)
402+
install_torch_trt(session)
403+
run_fx_core_tests(session)
404+
run_fx_converter_tests(session)
405+
run_fx_lower_tests(session)
406+
cleanup(session)
407+
408+
def run_l0_fx_core_tests(session):
409+
if not USE_HOST_DEPS:
410+
install_deps(session)
411+
install_torch_trt(session)
412+
run_fx_core_tests(session)
413+
cleanup(session)
414+
415+
def run_l0_fx_converter_tests(session):
416+
if not USE_HOST_DEPS:
417+
install_deps(session)
418+
install_torch_trt(session)
419+
run_fx_converter_tests(session)
420+
cleanup(session)
421+
422+
def run_l0_fx_lower_tests(session):
423+
if not USE_HOST_DEPS:
424+
install_deps(session)
425+
install_torch_trt(session)
426+
run_fx_lower_tests(session)
427+
cleanup(session)
312428

313429
def run_l0_dla_tests(session):
314430
if not USE_HOST_DEPS:
@@ -327,7 +443,6 @@ def run_l1_model_tests(session):
327443
run_model_tests(session)
328444
cleanup(session)
329445

330-
331446
def run_l1_int8_accuracy_tests(session):
332447
if not USE_HOST_DEPS:
333448
install_deps(session)
@@ -337,6 +452,14 @@ def run_l1_int8_accuracy_tests(session):
337452
run_int8_accuracy_tests(session)
338453
cleanup(session)
339454

455+
def run_l1_fx_tests(session):
456+
if not USE_HOST_DEPS:
457+
install_deps(session)
458+
install_torch_trt(session)
459+
run_fx_quant_tests(session)
460+
run_fx_tracer_tests(session)
461+
run_fx_tools_tests(session)
462+
cleanup(session)
340463

341464
def run_l2_trt_compatibility_tests(session):
342465
if not USE_HOST_DEPS:
@@ -360,6 +483,25 @@ def l0_api_tests(session):
360483
"""When a developer needs to check correctness for a PR or something"""
361484
run_l0_api_tests(session)
362485

486+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
487+
def l0_fx_tests(session):
488+
"""When a developer needs to check correctness for a PR or something"""
489+
run_l0_fx_tests(session)
490+
491+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
492+
def l0_fx_core_tests(session):
493+
"""When a developer needs to check correctness for a PR or something"""
494+
run_l0_fx_core_tests(session)
495+
496+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
497+
def l0_fx_converter_tests(session):
498+
"""When a developer needs to check correctness for a PR or something"""
499+
run_l0_fx_converter_tests(session)
500+
501+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
502+
def l0_fx_lower_tests(session):
503+
"""When a developer needs to check correctness for a PR or something"""
504+
run_l0_fx_lower_tests(session)
363505

364506
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
365507
def l0_dla_tests(session):
@@ -372,6 +514,10 @@ def l1_model_tests(session):
372514
"""When a user needs to test the functionality of standard models compilation and results"""
373515
run_l1_model_tests(session)
374516

517+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
518+
def l1_fx_tests(session):
519+
"""When a user needs to test the functionality of standard models compilation and results"""
520+
run_l1_fx_tests(session)
375521

376522
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
377523
def l1_int8_accuracy_tests(session):
@@ -388,4 +534,4 @@ def l2_trt_compatibility_tests(session):
388534
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
389535
def l2_multi_gpu_tests(session):
390536
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
391-
run_l2_multi_gpu_tests(session)
537+
run_l2_multi_gpu_tests(session)

0 commit comments

Comments
 (0)