@@ -513,6 +513,11 @@ def test_qnn_backend_log(self):
513
513
sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
514
514
self .lower_module_and_test_output (module , sample_input )
515
515
516
+ def test_qnn_backend_logical_not (self ):
517
+ module = LogicalNot () # noqa: F405
518
+ sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
519
+ self .lower_module_and_test_output (module , sample_input )
520
+
516
521
def test_qnn_backend_log_softmax (self ):
517
522
module = LogSoftmax () # noqa: F405
518
523
sample_input = (torch .randn ([1 , 4 , 8 , 8 ]),)
@@ -692,6 +697,18 @@ def test_qnn_backend_view(self):
692
697
sample_input = (torch .randn ([1 , 8 , 512 ]), torch .randn ([1 , 2 , 8 , 256 ]))
693
698
self .lower_module_and_test_output (module , sample_input )
694
699
700
+ def test_qnn_backend_where (self ):
701
+ modules = [
702
+ Where (), # noqa: F405
703
+ WhereConstant (torch .randn (3 , 2 ), torch .randn (3 , 2 )), # noqa: F405
704
+ ]
705
+ sample_inputs = [
706
+ (torch .randn (3 , 2 ), torch .randn (3 , 2 ), torch .randn (3 , 2 )),
707
+ (torch .randn (3 , 2 ),),
708
+ ]
709
+ for i , module in enumerate (modules ):
710
+ self .lower_module_and_test_output (module , sample_inputs [i ])
711
+
695
712
696
713
class TestQNNFloatingPointModel (TestQNN ):
697
714
# TODO: refactor to support different backends
@@ -1396,6 +1413,12 @@ def test_qnn_backend_log(self):
1396
1413
module = self .get_qdq_module (module , sample_input )
1397
1414
self .lower_module_and_test_output (module , sample_input )
1398
1415
1416
+ def test_qnn_backend_logical_not (self ):
1417
+ module = LogicalNot () # noqa: F405
1418
+ sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
1419
+ module = self .get_qdq_module (module , sample_input )
1420
+ self .lower_module_and_test_output (module , sample_input )
1421
+
1399
1422
def test_qnn_backend_log_softmax (self ):
1400
1423
module = LogSoftmax () # noqa: F405
1401
1424
sample_input = (torch .randn ([1 , 4 , 8 , 8 ]),)
@@ -1609,6 +1632,19 @@ def test_qnn_backend_view(self):
1609
1632
module = self .get_qdq_module (module , sample_input )
1610
1633
self .lower_module_and_test_output (module , sample_input )
1611
1634
1635
+ def test_qnn_backend_where (self ):
1636
+ modules = [
1637
+ Where (), # noqa: F405
1638
+ WhereConstant (torch .randn (3 , 2 ), torch .randn (3 , 2 )), # noqa: F405
1639
+ ]
1640
+ sample_inputs = [
1641
+ (torch .randn (3 , 2 ), torch .randn (3 , 2 ), torch .randn (3 , 2 )),
1642
+ (torch .randn (3 , 2 ),),
1643
+ ]
1644
+ for i , module in enumerate (modules ):
1645
+ module = self .get_qdq_module (module , sample_inputs [i ])
1646
+ self .lower_module_and_test_output (module , sample_inputs [i ])
1647
+
1612
1648
1613
1649
class TestQNNQuantizedModel (TestQNN ):
1614
1650
# TODO: refactor to support different backends
0 commit comments