Skip to content

Commit 875ef0e

Browse files
update pr to suggested changes
1 parent 06a2a8e commit 875ef0e

File tree

2 files changed

+25
-45
lines changed

2 files changed

+25
-45
lines changed

conftest.py

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,22 @@
33
import keras
44
import pytest
55

6+
# OpenVINO supported test paths
7+
OPENVINO_SUPPORTED_PATHS = [
8+
"keras-hub/integration_tests",
9+
"keras_hub/src/models/gemma",
10+
"keras_hub/src/models/gpt2",
11+
"keras_hub/src/models/mistral",
12+
"keras_hub/src/tokenizers",
13+
]
14+
15+
# OpenVINO specific test skips
16+
OPENVINO_SPECIFIC_SKIPPING_TESTS = {
17+
"test_backbone_basics": "bfloat16 dtype not supported",
18+
"test_score_loss": "Non-implemented roll operation",
19+
"test_causal_lm_basics": "requires trainable backend",
20+
}
21+
622

723
def pytest_addoption(parser):
824
parser.addoption(
@@ -34,15 +50,6 @@ def pytest_addoption(parser):
3450
def pytest_configure(config):
3551
# Monkey-patch training methods for OpenVINO backend
3652
if keras.config.backend() == "openvino":
37-
# Store original methods in case we need to restore them
38-
if not hasattr(keras.Model, "_original_compile"):
39-
keras.Model._original_compile = keras.Model.compile
40-
keras.Model._original_fit = keras.Model.fit
41-
keras.Model._original_train_on_batch = keras.Model.train_on_batch
42-
43-
keras.Model.compile = lambda *args, **kwargs: pytest.skip(
44-
"Model.compile() not supported on OpenVINO backend"
45-
)
4653
keras.Model.fit = lambda *args, **kwargs: pytest.skip(
4754
"Model.fit() not supported on OpenVINO backend"
4855
)
@@ -131,48 +138,22 @@ def pytest_collection_modifyitems(config, items):
131138
# OpenVINO-specific test skipping
132139
if keras.config.backend() == "openvino":
133140
test_name = item.name.split("[")[0]
134-
test_path = str(item.fspath)
135-
136-
# OpenVINO supported test paths
137-
openvino_supported_paths = [
138-
"keras-hub/integration_tests",
139-
"keras_hub/src/models/gemma",
140-
"keras_hub/src/models/gpt2",
141-
"keras_hub/src/models/mistral",
142-
"keras_hub/src/samplers/serialization_test.py",
143-
"keras_hub/src/tests/doc_tests/docstring_test.py",
144-
"keras_hub/src/tokenizers",
145-
"keras_hub/src/utils",
146-
]
147-
148-
# Skip specific problematic test methods
149-
specific_skipping_tests = {
150-
"test_backbone_basics": "Requires trainable backend",
151-
"test_score_loss": "Non-implemented roll operation",
152-
"test_layer_behaviors": "Requires trainable backend",
153-
}
154-
155-
if test_name in specific_skipping_tests:
141+
142+
if test_name in OPENVINO_SPECIFIC_SKIPPING_TESTS:
156143
item.add_marker(
157144
pytest.mark.skipif(
158145
True,
159146
reason="OpenVINO: "
160-
f"{specific_skipping_tests[test_name]}",
147+
f"{OPENVINO_SPECIFIC_SKIPPING_TESTS[test_name]}",
161148
)
162149
)
163150
continue
164151

165-
parts = test_path.replace("\\", "/").split("/")
166-
try:
167-
keras_hub_idx = parts.index("keras_hub")
168-
relative_test_path = "/".join(parts[keras_hub_idx:])
169-
except ValueError:
170-
relative_test_path = test_path
171-
172152
is_whitelisted = any(
173-
relative_test_path == supported_path
174-
or relative_test_path.startswith(supported_path + "/")
175-
for supported_path in openvino_supported_paths
153+
item.nodeid.startswith(supported_path + "/")
154+
or item.nodeid.startswith(supported_path + "::")
155+
or item.nodeid == supported_path
156+
for supported_path in OPENVINO_SUPPORTED_PATHS
176157
)
177158

178159
if not is_whitelisted:

keras_hub/src/utils/openvino_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22

33
from keras_hub.src.utils.keras_utils import print_msg
44

5+
_core = None
6+
57
try:
68
import openvino as ov
79
import openvino.opset14 as ov_opset
810
from openvino import Core
9-
10-
_core = None
1111
except ImportError:
1212
ov = None
1313
ov_opset = None
1414
Core = None
15-
_core = None
1615

1716

1817
def get_core():

0 commit comments

Comments
 (0)