5
5
6
6
# Use system installed Python packages
7
7
PYT_PATH = (
8
- "/opt/conda /lib/python3.8/site -packages"
8
+ "/usr/local /lib/python3.8/dist -packages"
9
9
if not "PYT_PATH" in os .environ
10
10
else os .environ ["PYT_PATH" ]
11
11
)
@@ -202,6 +202,93 @@ def run_base_tests(session):
202
202
else :
203
203
session .run_always ("pytest" , test )
204
204
205
+ def run_fx_core_tests (session ):
206
+ print ("Running FX core tests" )
207
+ session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
208
+ tests = [
209
+ "core" ,
210
+ ]
211
+ for test in tests :
212
+ if USE_HOST_DEPS :
213
+ session .run_always ("pytest" , test , env = {"PYTHONPATH" : PYT_PATH })
214
+ else :
215
+ session .run_always ("pytest" , test )
216
+
217
+ def run_fx_converter_tests (session ):
218
+ print ("Running FX converter tests" )
219
+ session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
220
+ tests = [
221
+ "converters" ,
222
+ ]
223
+ # Skipping this test as it fails inside NGC container with the following error.
224
+ # Error Code 4: Internal Error (Could not find any implementation for node conv due to insufficient workspace. See verbose log for requested sizes.)
225
+ skip_tests = "-k not conv3d"
226
+ for test in tests :
227
+ if USE_HOST_DEPS :
228
+ session .run_always ("pytest" , test , skip_tests , env = {"PYTHONPATH" : PYT_PATH })
229
+ else :
230
+ session .run_always ("pytest" , test , skip_tests )
231
+
232
+ def run_fx_lower_tests (session ):
233
+ print ("Running FX passes and trt_lower tests" )
234
+ session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
235
+ tests = [
236
+ "passes/test_multi_fuse_trt.py" ,
237
+ # "passes/test_fuse_permute_linear_trt.py",
238
+ "passes/test_remove_duplicate_output_args.py" ,
239
+ "passes/test_fuse_permute_matmul_trt.py" ,
240
+ #"passes/test_graph_opts.py"
241
+ "trt_lower" ,
242
+ ]
243
+ for test in tests :
244
+ if USE_HOST_DEPS :
245
+ session .run_always ("pytest" , test , env = {"PYTHONPATH" : PYT_PATH })
246
+ else :
247
+ session .run_always ("pytest" , test )
248
+
249
+ def run_fx_quant_tests (session ):
250
+ print ("Running FX Quant tests" )
251
+ session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
252
+ tests = [
253
+ "quant" ,
254
+ ]
255
+ # Skipping this test as it fails inside NGC container with the following error.
256
+ # ImportError: cannot import name 'ObservationType' from 'torch.ao.quantization.backend_config.observation_type'
257
+ skip_tests = "-k not conv_add_standalone_module"
258
+ for test in tests :
259
+ if USE_HOST_DEPS :
260
+ session .run_always ("pytest" , test , skip_tests , env = {"PYTHONPATH" : PYT_PATH })
261
+ else :
262
+ session .run_always ("pytest" , test , skip_tests )
263
+
264
+ def run_fx_tracer_tests (session ):
265
+ print ("Running FX Tracer tests" )
266
+ session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
267
+ # skipping a test since it depends on torchdynamo
268
+ # Enable this test once NGC moves to latest pytorch which has dynamo integrated.
269
+ tests = [
270
+ "tracer/test_acc_shape_prop.py" ,
271
+ "tracer/test_acc_tracer.py" ,
272
+ #"tracer/test_dispatch_tracer.py"
273
+ ]
274
+ for test in tests :
275
+ if USE_HOST_DEPS :
276
+ session .run_always ("pytest" , test , env = {"PYTHONPATH" : PYT_PATH })
277
+ else :
278
+ session .run_always ("pytest" , test )
279
+
280
+ def run_fx_tools_tests (session ):
281
+ print ("Running FX tools tests" )
282
+ session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
283
+ tests = [
284
+ "tools" ,
285
+ ]
286
+ for test in tests :
287
+ if USE_HOST_DEPS :
288
+ session .run_always ("pytest" , test , env = {"PYTHONPATH" : PYT_PATH })
289
+ else :
290
+ session .run_always ("pytest" , test )
291
+
205
292
206
293
def run_model_tests (session ):
207
294
print ("Running model tests" )
@@ -309,6 +396,35 @@ def run_l0_api_tests(session):
309
396
run_base_tests (session )
310
397
cleanup (session )
311
398
399
+ def run_l0_fx_tests (session ):
400
+ if not USE_HOST_DEPS :
401
+ install_deps (session )
402
+ install_torch_trt (session )
403
+ run_fx_core_tests (session )
404
+ run_fx_converter_tests (session )
405
+ run_fx_lower_tests (session )
406
+ cleanup (session )
407
+
408
+ def run_l0_fx_core_tests (session ):
409
+ if not USE_HOST_DEPS :
410
+ install_deps (session )
411
+ install_torch_trt (session )
412
+ run_fx_core_tests (session )
413
+ cleanup (session )
414
+
415
+ def run_l0_fx_converter_tests (session ):
416
+ if not USE_HOST_DEPS :
417
+ install_deps (session )
418
+ install_torch_trt (session )
419
+ run_fx_converter_tests (session )
420
+ cleanup (session )
421
+
422
+ def run_l0_fx_lower_tests (session ):
423
+ if not USE_HOST_DEPS :
424
+ install_deps (session )
425
+ install_torch_trt (session )
426
+ run_fx_lower_tests (session )
427
+ cleanup (session )
312
428
313
429
def run_l0_dla_tests (session ):
314
430
if not USE_HOST_DEPS :
@@ -327,7 +443,6 @@ def run_l1_model_tests(session):
327
443
run_model_tests (session )
328
444
cleanup (session )
329
445
330
-
331
446
def run_l1_int8_accuracy_tests (session ):
332
447
if not USE_HOST_DEPS :
333
448
install_deps (session )
@@ -337,6 +452,14 @@ def run_l1_int8_accuracy_tests(session):
337
452
run_int8_accuracy_tests (session )
338
453
cleanup (session )
339
454
455
+ def run_l1_fx_tests (session ):
456
+ if not USE_HOST_DEPS :
457
+ install_deps (session )
458
+ install_torch_trt (session )
459
+ run_fx_quant_tests (session )
460
+ run_fx_tracer_tests (session )
461
+ run_fx_tools_tests (session )
462
+ cleanup (session )
340
463
341
464
def run_l2_trt_compatibility_tests (session ):
342
465
if not USE_HOST_DEPS :
@@ -360,6 +483,25 @@ def l0_api_tests(session):
360
483
"""When a developer needs to check correctness for a PR or something"""
361
484
run_l0_api_tests (session )
362
485
486
+ @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
487
+ def l0_fx_tests (session ):
488
+ """When a developer needs to check correctness for a PR or something"""
489
+ run_l0_fx_tests (session )
490
+
491
+ @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
492
+ def l0_fx_core_tests (session ):
493
+ """When a developer needs to check correctness for a PR or something"""
494
+ run_l0_fx_core_tests (session )
495
+
496
+ @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
497
+ def l0_fx_converter_tests (session ):
498
+ """When a developer needs to check correctness for a PR or something"""
499
+ run_l0_fx_converter_tests (session )
500
+
501
+ @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
502
+ def l0_fx_lower_tests (session ):
503
+ """When a developer needs to check correctness for a PR or something"""
504
+ run_l0_fx_lower_tests (session )
363
505
364
506
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
365
507
def l0_dla_tests (session ):
@@ -372,6 +514,10 @@ def l1_model_tests(session):
372
514
"""When a user needs to test the functionality of standard models compilation and results"""
373
515
run_l1_model_tests (session )
374
516
517
+ @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
518
+ def l1_fx_tests (session ):
519
+ """When a user needs to test the functionality of standard models compilation and results"""
520
+ run_l1_fx_tests (session )
375
521
376
522
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
377
523
def l1_int8_accuracy_tests (session ):
@@ -388,4 +534,4 @@ def l2_trt_compatibility_tests(session):
388
534
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
389
535
def l2_multi_gpu_tests (session ):
390
536
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
391
- run_l2_multi_gpu_tests (session )
537
+ run_l2_multi_gpu_tests (session )
0 commit comments