From 8ae55a214873893cf2cf4a7998ea8cd494a96513 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 28 Jun 2024 13:58:06 +0400 Subject: [PATCH 1/3] [JAX FE]Correct creation of numpy dtypes Signed-off-by: Kazantsev, Roman --- .../python/src/openvino/frontend/jax/utils.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/jax/utils.py b/src/bindings/python/src/openvino/frontend/jax/utils.py index 7b860febed3645..64b57cfa40b2b0 100644 --- a/src/bindings/python/src/openvino/frontend/jax/utils.py +++ b/src/bindings/python/src/openvino/frontend/jax/utils.py @@ -11,17 +11,17 @@ from openvino.runtime import op, Type as OVType, Shape, OVAny numpy_to_ov_type_map = { - np.dtypes.Float32DType: OVType.f32, - np.dtypes.BoolDType: OVType.boolean, + np.dtype(np.float32): OVType.f32, + np.dtype(bool): OVType.boolean, jax.dtypes.bfloat16: OVType.bf16, # TODO: check this - np.dtypes.Float16DType: OVType.f16, - np.dtypes.Float32DType: OVType.f32, - np.dtypes.Float64DType: OVType.f64, - np.dtypes.UInt8DType: OVType.u8, - np.dtypes.Int8DType: OVType.i8, - np.dtypes.Int16DType: OVType.i16, - np.dtypes.Int32DType: OVType.i32, - np.dtypes.Int64DType: OVType.i64, + np.dtype(np.float16): OVType.f16, + np.dtype(np.float32): OVType.f32, + np.dtype(np.float64): OVType.f64, + np.dtype(np.uint8): OVType.u8, + np.dtype(np.int8): OVType.i8, + np.dtype(np.int16): OVType.i16, + np.dtype(np.int32): OVType.i32, + np.dtype(np.int64): OVType.i64, } jax_to_ov_type_map = { From 20c955824d119952cf7dc4466530acafcee5bcf5 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 28 Jun 2024 14:19:21 +0400 Subject: [PATCH 2/3] Simplify logic of checking dtypes Signed-off-by: Kazantsev, Roman --- .../python/src/openvino/frontend/jax/utils.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/jax/utils.py b/src/bindings/python/src/openvino/frontend/jax/utils.py index 64b57cfa40b2b0..09fa8487e523c7 100644 --- a/src/bindings/python/src/openvino/frontend/jax/utils.py +++ b/src/bindings/python/src/openvino/frontend/jax/utils.py @@ -11,17 +11,17 @@ from openvino.runtime import op, Type as OVType, Shape, OVAny numpy_to_ov_type_map = { - np.dtype(np.float32): OVType.f32, - np.dtype(bool): OVType.boolean, + np.float32: OVType.f32, + bool: OVType.boolean, jax.dtypes.bfloat16: OVType.bf16, # TODO: check this - np.dtype(np.float16): OVType.f16, - np.dtype(np.float32): OVType.f32, - np.dtype(np.float64): OVType.f64, - np.dtype(np.uint8): OVType.u8, - np.dtype(np.int8): OVType.i8, - np.dtype(np.int16): OVType.i16, - np.dtype(np.int32): OVType.i32, - np.dtype(np.int64): OVType.i64, + np.float16: OVType.f16, + np.float32: OVType.f32, + np.float64: OVType.f64, + np.uint8: OVType.u8, + np.int8: OVType.i8, + np.int16: OVType.i16, + np.int32: OVType.i32, + np.int64: OVType.i64, } jax_to_ov_type_map = { @@ -74,7 +74,7 @@ def get_ov_type_for_value(value): if value.aval.dtype in jax_to_ov_type_map: return OVAny(jax_to_ov_type_map[value.aval.dtype]) for k, v in numpy_to_ov_type_map.items(): - if isinstance(value.aval.dtype, k): + if value.aval.dtype == k: return OVAny(v) for k, v in basic_to_ov_type_map.items(): if isinstance(value.aval.dtype, k): @@ -88,7 +88,7 @@ def get_ov_type_from_jax_type(dtype): if dtype in jax_to_ov_type_map: return OVAny(jax_to_ov_type_map[dtype]) for k, v in numpy_to_ov_type_map.items(): - if isinstance(dtype, k): + if dtype == k: return OVAny(v) for k, v in basic_to_ov_type_map.items(): if isinstance(dtype, k): From 0a12ceb863c4ff975186d5189030efca79c251b7 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 28 Jun 2024 15:36:05 +0400 Subject: [PATCH 3/3] Correct labeling for JAX FE sources Signed-off-by: Kazantsev, Roman --- .github/labeler.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/labeler.yml b/.github/labeler.yml index eea70d31684e4c..64a8661cf1e2e8 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -141,6 +141,7 @@ 'category: TF FE': - 'src/frontends/tensorflow/**/*' - 'src/frontends/tensorflow_common/**/*' +- 'src/bindings/python/src/openvino/frontend/tensorflow/**/*' - 'tests/layer_tests/tensorflow_tests/**/*' - 'tests/layer_tests/tensorflow2_keras_tests/**/*' - 'tests/layer_tests/jax_tests/**/*' @@ -163,6 +164,7 @@ 'category: JAX FE': - 'src/frontends/jax/**/*' +- 'src/bindings/python/src/openvino/frontend/jax/**/*' - 'tests/layer_tests/jax_tests/**/*' 'category: tools':