11import os
2- from pathlib import Path
32
43import keras
54import pytest
65
7- from keras_hub .src .utils .openvino_utils import get_openvino_skip_reason
8- from keras_hub .src .utils .openvino_utils import setup_openvino_test_config
9-
106
117def pytest_addoption (parser ):
128 parser .addoption (
@@ -33,16 +29,27 @@ def pytest_addoption(parser):
3329 default = False ,
3430 help = "fail if a gpu is not present" ,
3531 )
36- parser .addoption (
37- "--auto_skip_training" ,
38- action = "store_true" ,
39- default = True ,
40- help = "automatically skip tests with "
41- "training methods on non-trainable backends" ,
42- )
4332
4433
4534def pytest_configure (config ):
35+ # Monkey-patch training methods for OpenVINO backend
36+ if keras .config .backend () == "openvino" :
37+ # Store original methods in case we need to restore them
38+ if not hasattr (keras .Model , "_original_compile" ):
39+ keras .Model ._original_compile = keras .Model .compile
40+ keras .Model ._original_fit = keras .Model .fit
41+ keras .Model ._original_train_on_batch = keras .Model .train_on_batch
42+
43+ keras .Model .compile = lambda * args , ** kwargs : pytest .skip (
44+ "Model.compile() not supported on OpenVINO backend"
45+ )
46+ keras .Model .fit = lambda * args , ** kwargs : pytest .skip (
47+ "Model.fit() not supported on OpenVINO backend"
48+ )
49+ keras .Model .train_on_batch = lambda * args , ** kwargs : pytest .skip (
50+ "Model.train_on_batch() not supported on OpenVINO backend"
51+ )
52+
4653 # Verify that device has GPU and detected by backend
4754 if config .getoption ("--check_gpu" ):
4855 found_gpu = False
@@ -84,12 +91,9 @@ def pytest_configure(config):
8491
8592
8693def pytest_collection_modifyitems (config , items ):
87- openvino_supported_paths = None
88-
8994 run_extra_large_tests = config .getoption ("--run_extra_large" )
9095 # Run large tests for --run_extra_large or --run_large.
9196 run_large_tests = config .getoption ("--run_large" ) or run_extra_large_tests
92- auto_skip_training = config .getoption ("--auto_skip_training" )
9397
9498 # Messages to annotate skipped tests with.
9599 skip_large = pytest .mark .skipif (
@@ -124,21 +128,58 @@ def pytest_collection_modifyitems(config, items):
124128 if "kaggle_key_required" in item .keywords :
125129 item .add_marker (kaggle_key_required )
126130
127- # OpenVINO-specific skipping logic - whitelist-based approach
131+ # OpenVINO-specific test skipping
128132 if keras .config .backend () == "openvino" :
129- # OpenVINO backend configuration
130- if openvino_supported_paths is None :
131- openvino_supported_paths = setup_openvino_test_config (
132- str (Path (__file__ ).parent )
133+ test_name = item .name .split ("[" )[0 ]
134+ test_path = str (item .fspath )
135+
136+ # OpenVINO supported test paths
137+ openvino_supported_paths = [
138+ "keras-hub/integration_tests" ,
139+ "keras_hub/src/models/gemma" ,
140+ "keras_hub/src/models/gpt2" ,
141+ "keras_hub/src/models/mistral" ,
142+ "keras_hub/src/samplers/serialization_test.py" ,
143+ "keras_hub/src/tests/doc_tests/docstring_test.py" ,
144+ "keras_hub/src/tokenizers" ,
145+ "keras_hub/src/utils" ,
146+ ]
147+
148+ # Skip specific problematic test methods
149+ specific_skipping_tests = {
150+ "test_backbone_basics" : "Requires trainable backend" ,
151+ "test_score_loss" : "Non-implemented roll operation" ,
152+ "test_layer_behaviors" : "Requires trainable backend" ,
153+ }
154+
155+ if test_name in specific_skipping_tests :
156+ item .add_marker (
157+ pytest .mark .skipif (
158+ True ,
159+ reason = "OpenVINO: "
160+ f"{ specific_skipping_tests [test_name ]} " ,
161+ )
133162 )
134- skip_reason = get_openvino_skip_reason (
135- item ,
136- openvino_supported_paths ,
137- auto_skip_training ,
163+ continue
164+
165+ parts = test_path .replace ("\\ " , "/" ).split ("/" )
166+ try :
167+ keras_hub_idx = parts .index ("keras_hub" )
168+ relative_test_path = "/" .join (parts [keras_hub_idx :])
169+ except ValueError :
170+ relative_test_path = test_path
171+
172+ is_whitelisted = any (
173+ relative_test_path == supported_path
174+ or relative_test_path .startswith (supported_path + "/" )
175+ for supported_path in openvino_supported_paths
138176 )
139- if skip_reason :
177+
178+ if not is_whitelisted :
140179 item .add_marker (
141- pytest .mark .skipif (True , reason = f"OpenVINO: { skip_reason } " )
180+ pytest .mark .skipif (
181+ True , reason = "OpenVINO: File/directory not in whitelist"
182+ )
142183 )
143184
144185
0 commit comments