diff --git a/.github/labeler.yml b/.github/labeler.yml index eea70d31684e4c..64a8661cf1e2e8 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -141,6 +141,7 @@ 'category: TF FE': - 'src/frontends/tensorflow/**/*' - 'src/frontends/tensorflow_common/**/*' +- 'src/bindings/python/src/openvino/frontend/tensorflow/**/*' - 'tests/layer_tests/tensorflow_tests/**/*' - 'tests/layer_tests/tensorflow2_keras_tests/**/*' - 'tests/layer_tests/jax_tests/**/*' @@ -163,6 +164,7 @@ 'category: JAX FE': - 'src/frontends/jax/**/*' +- 'src/bindings/python/src/openvino/frontend/jax/**/*' - 'tests/layer_tests/jax_tests/**/*' 'category: tools': diff --git a/src/bindings/python/src/openvino/frontend/jax/utils.py b/src/bindings/python/src/openvino/frontend/jax/utils.py index 7b860febed3645..09fa8487e523c7 100644 --- a/src/bindings/python/src/openvino/frontend/jax/utils.py +++ b/src/bindings/python/src/openvino/frontend/jax/utils.py @@ -11,17 +11,17 @@ from openvino.runtime import op, Type as OVType, Shape, OVAny numpy_to_ov_type_map = { - np.dtypes.Float32DType: OVType.f32, - np.dtypes.BoolDType: OVType.boolean, + np.float32: OVType.f32, + bool: OVType.boolean, jax.dtypes.bfloat16: OVType.bf16, # TODO: check this - np.dtypes.Float16DType: OVType.f16, - np.dtypes.Float32DType: OVType.f32, - np.dtypes.Float64DType: OVType.f64, - np.dtypes.UInt8DType: OVType.u8, - np.dtypes.Int8DType: OVType.i8, - np.dtypes.Int16DType: OVType.i16, - np.dtypes.Int32DType: OVType.i32, - np.dtypes.Int64DType: OVType.i64, + np.float16: OVType.f16, + np.float32: OVType.f32, + np.float64: OVType.f64, + np.uint8: OVType.u8, + np.int8: OVType.i8, + np.int16: OVType.i16, + np.int32: OVType.i32, + np.int64: OVType.i64, } jax_to_ov_type_map = { @@ -74,7 +74,7 @@ def get_ov_type_for_value(value): if value.aval.dtype in jax_to_ov_type_map: return OVAny(jax_to_ov_type_map[value.aval.dtype]) for k, v in numpy_to_ov_type_map.items(): - if isinstance(value.aval.dtype, k): + if value.aval.dtype == k: return OVAny(v) for k, v in basic_to_ov_type_map.items(): if isinstance(value.aval.dtype, k): @@ -88,7 +88,7 @@ def get_ov_type_from_jax_type(dtype): if dtype in jax_to_ov_type_map: return OVAny(jax_to_ov_type_map[dtype]) for k, v in numpy_to_ov_type_map.items(): - if isinstance(dtype, k): + if dtype == k: return OVAny(v) for k, v in basic_to_ov_type_map.items(): if isinstance(dtype, k):