diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index 15facbd0a5..33a3c5708b 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -23,6 +23,8 @@ jobs:
             version: keras-3.8
           - backend: jax
             version: keras-nightly
+          - backend: openvino
+            version: keras-nightly
     runs-on: ubuntu-latest
     env:
       KERAS_BACKEND: ${{ matrix.backend }}
diff --git a/conftest.py b/conftest.py
index a5f40eb789..71c2e4ed9c 100644
--- a/conftest.py
+++ b/conftest.py
@@ -3,6 +3,22 @@
 import keras
 import pytest
 
+# OpenVINO supported test paths
+OPENVINO_SUPPORTED_PATHS = [
+    "keras-hub/integration_tests",
+    "keras_hub/src/models/gemma",
+    "keras_hub/src/models/gpt2",
+    "keras_hub/src/models/mistral",
+    "keras_hub/src/tokenizers",
+]
+
+# OpenVINO specific test skips
+OPENVINO_SPECIFIC_SKIPPING_TESTS = {
+    "test_backbone_basics": "bfloat16 dtype not supported",
+    "test_score_loss": "Non-implemented roll operation",
+    "test_causal_lm_basics": "Missing ops and requires trainable backend",
+}
+
 
 def pytest_addoption(parser):
     parser.addoption(
@@ -32,6 +48,15 @@ def pytest_addoption(parser):
 
 
 def pytest_configure(config):
+    # Monkey-patch training methods for OpenVINO backend
+    if keras.config.backend() == "openvino":
+        keras.Model.fit = lambda *args, **kwargs: pytest.skip(
+            "Model.fit() not supported on OpenVINO backend"
+        )
+        keras.Model.train_on_batch = lambda *args, **kwargs: pytest.skip(
+            "Model.train_on_batch() not supported on OpenVINO backend"
+        )
+
     # Verify that device has GPU and detected by backend
     if config.getoption("--check_gpu"):
         found_gpu = False
@@ -110,6 +135,34 @@ def pytest_collection_modifyitems(config, items):
         if "kaggle_key_required" in item.keywords:
             item.add_marker(kaggle_key_required)
 
+        # OpenVINO-specific test skipping
+        if keras.config.backend() == "openvino":
+            test_name = item.name.split("[")[0]
+
+            if test_name in OPENVINO_SPECIFIC_SKIPPING_TESTS:
+                item.add_marker(
+                    pytest.mark.skipif(
+                        True,
+                        reason="OpenVINO: "
+                        f"{OPENVINO_SPECIFIC_SKIPPING_TESTS[test_name]}",
+                    )
+                )
+                continue
+
+            is_whitelisted = any(
+                item.nodeid.startswith(supported_path + "/")
+                or item.nodeid.startswith(supported_path + "::")
+                or item.nodeid == supported_path
+                for supported_path in OPENVINO_SUPPORTED_PATHS
+            )
+
+            if not is_whitelisted:
+                item.add_marker(
+                    pytest.mark.skipif(
+                        True, reason="OpenVINO: File/directory not in whitelist"
+                    )
+                )
+
 
 # Disable traceback filtering for quicker debugging of tests failures.
 keras.config.disable_traceback_filtering()
diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py
index 0e31d2c5a2..a12ac33303 100644
--- a/keras_hub/src/models/causal_lm.py
+++ b/keras_hub/src/models/causal_lm.py
@@ -132,6 +132,17 @@ def make_generate_function(self):
             return self.generate_function
 
         self.generate_function = self.generate_step
+        if keras.config.backend() == "openvino":
+            from keras_hub.src.utils.openvino_utils import ov_infer
+
+            def wrapped_generate_function(inputs, stop_token_ids=None):
+                # Convert to numpy for OpenVINO backend
+                inputs = tree.map_structure(ops.array, inputs)
+                return ov_infer(
+                    self, inputs, stop_token_ids, self.generate_step
+                )
+
+            self.generate_function = wrapped_generate_function
         if keras.config.backend() == "torch":
             import torch
diff --git a/keras_hub/src/samplers/beam_sampler.py b/keras_hub/src/samplers/beam_sampler.py
index 26941e9f3a..c2e605b234 100644
--- a/keras_hub/src/samplers/beam_sampler.py
+++ b/keras_hub/src/samplers/beam_sampler.py
@@ -95,15 +95,15 @@ def unflatten_beams(x):
         )
         log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0))
 
-        def cond(prompt, cache, index, log_probs):
+        def cond(prompt, cache, index, mask, log_probs):
             if stop_token_ids is None:
-                return True
+                return ops.convert_to_tensor(True, dtype="bool")
             # Stop if all sequences have produced a *new* stop token.
             end_tokens = any_equal(prompt, stop_token_ids, ~mask)
             prompt_done = ops.any(end_tokens, axis=-1)
             return ops.logical_not(ops.all(prompt_done))
 
-        def body(prompt, cache, index, log_probs):
+        def body(prompt, cache, index, mask, log_probs):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
             vocab_size = ops.shape(logits)[-1]
@@ -150,12 +150,12 @@ def gather_beams(x):
             next_token = next_token[:, None]
             prompt = ops.slice_update(prompt, [0, index], next_token)
             # Return the iteration of the loop state.
-            return (prompt, cache, index + 1, log_probs)
+            return (prompt, cache, index + 1, mask, log_probs)
 
-        prompt, _, _, log_probs = self.run_loop(
+        prompt, _, _, _, log_probs = self.run_loop(
             cond=cond,
             body=body,
-            loop_vars=(prompt, cache, index, log_probs),
+            loop_vars=(prompt, cache, index, mask, log_probs),
             maximum_iterations=(max_length - index),
             model=model,
         )
diff --git a/keras_hub/src/samplers/sampler.py b/keras_hub/src/samplers/sampler.py
index e3dd2627ee..44c4168375 100644
--- a/keras_hub/src/samplers/sampler.py
+++ b/keras_hub/src/samplers/sampler.py
@@ -92,16 +92,18 @@ def __call__(
         # `ops.while_loop` will not accept `None` as a value for `loop_vars`.
         cache = () if cache is None else cache
 
-        def cond(prompt, cache, index):
+        # OpenVINO requires all parameters to be passed in the body.
+        # So we pass `mask` as well.
+        def cond(prompt, cache, index, mask):
             if stop_token_ids is None:
-                return True
+                return ops.convert_to_tensor(True, dtype="bool")
             # Stop if all sequences have produced a *new* id from
             # stop_token_ids.
             end_tokens = any_equal(prompt, stop_token_ids, ~mask)
             prompt_done = ops.any(end_tokens, axis=-1)
             return ops.logical_not(ops.all(prompt_done))
 
-        def body(prompt, cache, index):
+        def body(prompt, cache, index, mask):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
             probabilities = self.compute_probabilities(logits)
@@ -115,12 +117,12 @@ def body(prompt, cache, index):
             prompt = ops.slice_update(prompt, [0, index], next_token)
             # Return the next prompt, cache and incremented index.
-            return (prompt, cache, index + 1)
+            return (prompt, cache, index + 1, mask)
 
-        prompt, _, _ = self.run_loop(
+        prompt, _, _, _ = self.run_loop(
             cond,
             body,
-            loop_vars=(prompt, cache, index),
+            loop_vars=(prompt, cache, index, mask),
             maximum_iterations=(max_length - index),
             model=model,
         )
 
diff --git a/keras_hub/src/utils/openvino_utils.py b/keras_hub/src/utils/openvino_utils.py
new file mode 100644
index 0000000000..68570e0d15
--- /dev/null
+++ b/keras_hub/src/utils/openvino_utils.py
@@ -0,0 +1,141 @@
+from keras import tree
+
+from keras_hub.src.utils.keras_utils import print_msg
+
+try:
+    import openvino as ov
+    import openvino.opset14 as ov_opset
+    from openvino import Core
+except ImportError:
+    ov = None
+    ov_opset = None
+    Core = None
+
+
+_core = None
+
+
+def get_core():
+    """Get or create OpenVINO Core instance.
+
+    Returns:
+        openvino.Core: OpenVINO Core instance,
+            or None if OpenVINO not available.
+    """
+    global _core
+    if _core is None and Core is not None:
+        _core = Core()
+    return _core
+
+
+def get_device():
+    """Detect and return the best available OpenVINO device.
+
+    Returns:
+        str: "GPU" if available, otherwise "CPU".
+    """
+    core = get_core()
+    if core is None:
+        return "CPU"
+    return "GPU" if "GPU" in core.available_devices else "CPU"
+
+
+def compile_model(struct_params, struct_outputs, device, model_dtype):
+    """Compile OpenVINO model with dynamic shapes and precision hints.
+
+    Args:
+        struct_params: Model parameters structure.
+        struct_outputs: Model outputs structure.
+        device: Target device ("GPU" or "CPU").
+        model_dtype: Model precision ("f16" or "f32").
+
+    Returns:
+        Compiled OpenVINO model ready for inference.
+    """
+    flat_params = tree.flatten(struct_params)
+    flat_outputs = tree.flatten(struct_outputs)
+    parameters = [p.output.get_node() for p in flat_params]
+    results = [ov_opset.result(r.output) for r in flat_outputs]
+    ov_model = ov.Model(results=results, parameters=parameters)
+    for ov_input in ov_model.inputs:
+        rank = ov_input.get_partial_shape().rank.get_length()
+        ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank))
+    ov_model.validate_nodes_and_infer_types()
+    config = {"INFERENCE_PRECISION_HINT": model_dtype}
+    core = get_core()
+    if core is None:
+        raise RuntimeError("OpenVINO not available")
+    return core.compile_model(ov_model, device, config)
+
+
+def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
+    """Execute compiled OpenVINO model and return structured outputs.
+
+    Args:
+        inputs: Input tensors for inference.
+        struct_outputs: Expected output structure.
+        compiled_ov_model: Compiled OpenVINO model.
+        unpack_singleton: Function to unpack singleton outputs.
+
+    Returns:
+        Structured model outputs matching expected format.
+    """
+    flatten_inputs = tree.flatten(inputs)
+    raw = compiled_ov_model(flatten_inputs).to_tuple()
+    packed = tree.pack_sequence_as(struct_outputs, raw)
+    return unpack_singleton(packed)
+
+
+def ov_infer(model, inputs, stop_token_ids, fn):
+    """High-level OpenVINO inference with model reuse and compilation.
+
+    This function manages OpenVINO model compilation and caching. It reuses
+    existing compiled models when possible, or compiles new ones as needed.
+    Handles device detection and automatic precision selection.
+
+    Args:
+        model: Keras model with OpenVINO backend support.
+        inputs: Input tensors for inference.
+        stop_token_ids: Token IDs that should stop generation.
+        fn: Function to execute with the parameterized inputs.
+
+    Returns:
+        Model outputs from OpenVINO inference.
+    """
+    device = get_device()
+
+    # Try to use existing compiled model for the same device
+    if (
+        getattr(model, "ov_compiled_model", None) is not None
+        and getattr(model, "ov_device", None) is not None
+        and device == model.ov_device
+    ):
+        try:
+            return get_outputs(
+                inputs,
+                model.struct_outputs,
+                model.ov_compiled_model,
+                model._unpack_singleton,
+            )
+        except RuntimeError as e:
+            print_msg(
+                "WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
+                "recompiling model and trying again.\n" + str(e)
+            )
+            model.ov_compiled_model = None
+            model.struct_outputs = None
+
+    # Compile a new model
+    struct_params = model._parameterize_data(inputs)
+    model.struct_outputs = fn(struct_params, stop_token_ids)
+    model.ov_device = device
+    model_dtype = "f16" if model.dtype in ("float16", "bfloat16") else "f32"
+    model.ov_compiled_model = compile_model(
+        struct_params, model.struct_outputs, device, model_dtype
+    )
+    return get_outputs(
+        inputs,
+        model.struct_outputs,
+        model.ov_compiled_model,
+        model._unpack_singleton,
+    )
diff --git a/keras_hub/src/utils/openvino_utils_test.py b/keras_hub/src/utils/openvino_utils_test.py
new file mode 100644
index 0000000000..8fe8acc830
--- /dev/null
+++ b/keras_hub/src/utils/openvino_utils_test.py
@@ -0,0 +1,185 @@
+import unittest.mock
+
+import keras
+import numpy as np
+import openvino as ov
+import pytest
+from openvino import Core
+
+from keras_hub.src.tests.test_case import TestCase
+from keras_hub.src.utils.openvino_utils import compile_model
+from keras_hub.src.utils.openvino_utils import get_device
+from keras_hub.src.utils.openvino_utils import get_outputs
+from keras_hub.src.utils.openvino_utils import ov_infer
+
+
+@pytest.mark.skipif(
+    keras.config.backend() != "openvino",
+    reason="OpenVINO tests only run with OpenVINO backend",
+)
+class OpenVINOUtilsTest(TestCase):
+    def setUp(self):
+        super().setUp()
+        if ov is None:
+            self.skipTest("OpenVINO not available")
+
+    def test_get_device_returns_valid_device(self):
+        device = get_device()
+        self.assertIn(device, ["GPU", "CPU"])
+
+        core = Core()
+        self.assertIn(device, core.available_devices)
+
+    def test_get_device_consistency(self):
+        device1 = get_device()
+        device2 = get_device()
+        self.assertEqual(device1, device2)
+
+    def test_compile_model_basic_and_precision_hints(self):
+        class _MockParam:
+            def __init__(self):
+                self.output = unittest.mock.MagicMock()
+                self.output.get_node.return_value = unittest.mock.MagicMock()
+
+        class _MockOutput:
+            def __init__(self):
+                self.output = unittest.mock.MagicMock()
+
+        with (
+            unittest.mock.patch(
+                "keras_hub.src.utils.openvino_utils.ov.Model"
+            ) as mock_model_class,
+            unittest.mock.patch(
+                "keras_hub.src.utils.openvino_utils.get_core"
+            ) as mock_get_core,
+        ):
+            mock_model_class.return_value = unittest.mock.MagicMock()
+            mock_core = unittest.mock.MagicMock()
+            mock_get_core.return_value = mock_core
+            mock_core.compile_model.return_value = unittest.mock.MagicMock()
+
+            struct_params = [_MockParam(), _MockParam()]
+            struct_outputs = [_MockOutput()]
+            device = "CPU"
+
+            for dtype in ("f32", "f16"):
+                with self.subTest(dtype=dtype):
+                    result = compile_model(
+                        struct_params, struct_outputs, device, dtype
+                    )
+                    self.assertIsNotNone(result)
+
+            self.assertEqual(mock_core.compile_model.call_count, 2)
+
+    def test_get_outputs_basic_functionality(self):
+        class MockResult:
+            def __init__(self, data):
+                self.data = data
+
+            def to_tuple(self):
+                return (self.data,)
+
+        class MockCompiledModel:
+            def __init__(self):
+                self.inputs = ["input"]
+                self.outputs = ["output"]
+
+            def __call__(self, flatten_inputs):
+                input_data = flatten_inputs[0]
+                output_data = np.maximum(input_data, 0.0)
+                return MockResult(output_data)
+
+        class MockOutput:
+            def get_node(self):
+                return "mock_relu_node"
+
+        compiled_model = MockCompiledModel()
+        struct_outputs = [MockOutput()]
+
+        test_input = np.array([[-1.0, 0.0, 1.0]], dtype=np.float32)
+        inputs = [test_input]
+
+        def mock_unpack_singleton(x):
+            return x[0] if len(x) == 1 else x
+
+        outputs = get_outputs(
+            inputs, struct_outputs, compiled_model, mock_unpack_singleton
+        )
+        expected = np.array([[0.0, 0.0, 1.0]], dtype=np.float32)
+        np.testing.assert_array_almost_equal(outputs, expected)
+
+    def test_ov_infer_model_caching(self):
+        current_device = get_device()
+
+        class MockModel:
+            def __init__(self):
+                self.dtype = "float32"
+                self.ov_compiled_model = unittest.mock.MagicMock()
+                self.ov_device = current_device
+                self.struct_outputs = ["mock_output"]
+
+            def _parameterize_data(self, inputs):
+                return ["mock_param"]
+
+            def _unpack_singleton(self, x):
+                return x[0] if len(x) == 1 else x
+
+        def mock_fn(struct_params, stop_token_ids):
+            return ["mock_output"]
+
+        model = MockModel()
+        test_input = [np.array([[1.0, 2.0, 3.0]], dtype=np.float32)]
+        cached_model = model.ov_compiled_model
+
+        with unittest.mock.patch(
+            "keras_hub.src.utils.openvino_utils.get_outputs"
+        ) as mock_get_outputs:
+            mock_get_outputs.return_value = np.array(
+                [[2.0, 4.0, 6.0]], dtype=np.float32
+            )
+            result = ov_infer(model, test_input, None, mock_fn)
+
+        self.assertIs(model.ov_compiled_model, cached_model)
+        self.assertIsNotNone(result)
+
+    def test_ov_infer_dtype_selection(self):
+        class MockModel:
+            def __init__(self, dtype):
+                self.dtype = dtype
+                self.ov_compiled_model = None
+                self.ov_device = None
+                self.struct_outputs = None
+
+            def _parameterize_data(self, inputs):
+                return ["mock_param"]
+
+            def _unpack_singleton(self, x):
+                return x[0] if len(x) == 1 else x
+
+        def mock_fn(struct_params, stop_token_ids):
+            return ["mock_output"]
+
+        test_cases = [
+            ("float32", "f32"),
+            ("float16", "f16"),
+            ("bfloat16", "f16"),
+        ]
+        for model_dtype, expected_ov_dtype in test_cases:
+            with self.subTest(dtype=model_dtype):
+                model = MockModel(model_dtype)
+                test_input = [np.array([[1.0, 2.0]], dtype=np.float32)]
+                with (
+                    unittest.mock.patch(
+                        "keras_hub.src.utils.openvino_utils.compile_model"
+                    ) as mock_compile,
+                    unittest.mock.patch(
+                        "keras_hub.src.utils.openvino_utils.get_outputs"
+                    ) as mock_get_outputs,
+                ):
+                    mock_compile.return_value = "mock_compiled_model"
+                    mock_get_outputs.return_value = np.array(
+                        [[1.0, 2.0]], dtype=np.float32
+                    )
+                    ov_infer(model, test_input, None, mock_fn)
+                    args, kwargs = mock_compile.call_args
+                    self.assertEqual(args[3], expected_ov_dtype)
diff --git a/requirements-common.txt b/requirements-common.txt
index a98ed71301..a258d1cd85 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -18,4 +18,5 @@ sentencepiece
 tensorflow-datasets
 safetensors
 pillow
+openvino
 transformers