|
3 | 3 | import keras |
4 | 4 | import pytest |
5 | 5 |
|
| 6 | +# OpenVINO supported test paths |
| 7 | +OPENVINO_SUPPORTED_PATHS = [ |
| 8 | + "keras-hub/integration_tests", |
| 9 | + "keras_hub/src/models/gemma", |
| 10 | + "keras_hub/src/models/gpt2", |
| 11 | + "keras_hub/src/models/mistral", |
| 12 | + "keras_hub/src/tokenizers", |
| 13 | +] |
| 14 | + |
| 15 | +# OpenVINO specific test skips |
| 16 | +OPENVINO_SPECIFIC_SKIPPING_TESTS = { |
| 17 | + "test_backbone_basics": "bfloat16 dtype not supported", |
| 18 | + "test_score_loss": "Non-implemented roll operation", |
| 19 | + "test_causal_lm_basics": "requires trainable backend", |
| 20 | +} |
| 21 | + |
6 | 22 |
|
7 | 23 | def pytest_addoption(parser): |
8 | 24 | parser.addoption( |
@@ -34,15 +50,6 @@ def pytest_addoption(parser): |
34 | 50 | def pytest_configure(config): |
35 | 51 | # Monkey-patch training methods for OpenVINO backend |
36 | 52 | if keras.config.backend() == "openvino": |
37 | | - # Store original methods in case we need to restore them |
38 | | - if not hasattr(keras.Model, "_original_compile"): |
39 | | - keras.Model._original_compile = keras.Model.compile |
40 | | - keras.Model._original_fit = keras.Model.fit |
41 | | - keras.Model._original_train_on_batch = keras.Model.train_on_batch |
42 | | - |
43 | | - keras.Model.compile = lambda *args, **kwargs: pytest.skip( |
44 | | - "Model.compile() not supported on OpenVINO backend" |
45 | | - ) |
46 | 53 | keras.Model.fit = lambda *args, **kwargs: pytest.skip( |
47 | 54 | "Model.fit() not supported on OpenVINO backend" |
48 | 55 | ) |
@@ -131,48 +138,22 @@ def pytest_collection_modifyitems(config, items): |
131 | 138 | # OpenVINO-specific test skipping |
132 | 139 | if keras.config.backend() == "openvino": |
133 | 140 | test_name = item.name.split("[")[0] |
134 | | - test_path = str(item.fspath) |
135 | | - |
136 | | - # OpenVINO supported test paths |
137 | | - openvino_supported_paths = [ |
138 | | - "keras-hub/integration_tests", |
139 | | - "keras_hub/src/models/gemma", |
140 | | - "keras_hub/src/models/gpt2", |
141 | | - "keras_hub/src/models/mistral", |
142 | | - "keras_hub/src/samplers/serialization_test.py", |
143 | | - "keras_hub/src/tests/doc_tests/docstring_test.py", |
144 | | - "keras_hub/src/tokenizers", |
145 | | - "keras_hub/src/utils", |
146 | | - ] |
147 | | - |
148 | | - # Skip specific problematic test methods |
149 | | - specific_skipping_tests = { |
150 | | - "test_backbone_basics": "Requires trainable backend", |
151 | | - "test_score_loss": "Non-implemented roll operation", |
152 | | - "test_layer_behaviors": "Requires trainable backend", |
153 | | - } |
154 | | - |
155 | | - if test_name in specific_skipping_tests: |
| 141 | + |
| 142 | + if test_name in OPENVINO_SPECIFIC_SKIPPING_TESTS: |
156 | 143 | item.add_marker( |
157 | 144 | pytest.mark.skipif( |
158 | 145 | True, |
159 | 146 | reason="OpenVINO: " |
160 | | - f"{specific_skipping_tests[test_name]}", |
| 147 | + f"{OPENVINO_SPECIFIC_SKIPPING_TESTS[test_name]}", |
161 | 148 | ) |
162 | 149 | ) |
163 | 150 | continue |
164 | 151 |
|
165 | | - parts = test_path.replace("\\", "/").split("/") |
166 | | - try: |
167 | | - keras_hub_idx = parts.index("keras_hub") |
168 | | - relative_test_path = "/".join(parts[keras_hub_idx:]) |
169 | | - except ValueError: |
170 | | - relative_test_path = test_path |
171 | | - |
172 | 152 | is_whitelisted = any( |
173 | | - relative_test_path == supported_path |
174 | | - or relative_test_path.startswith(supported_path + "/") |
175 | | - for supported_path in openvino_supported_paths |
| 153 | + item.nodeid.startswith(supported_path + "/") |
| 154 | + or item.nodeid.startswith(supported_path + "::") |
| 155 | + or item.nodeid == supported_path |
| 156 | + for supported_path in OPENVINO_SUPPORTED_PATHS |
176 | 157 | ) |
177 | 158 |
|
178 | 159 | if not is_whitelisted: |
|
0 commit comments