Skip to content

Commit 62a1f1b

Browse files
authored
[JAX FE] Correct creation of numpy dtypes (#25278)
**Details:** Correct creation of numpy dtypes dictionary. `dtypes` looks not accessible for some numpy versions. And it leads to error: AttributeError: module 'numpy' has no attribute 'dtypes'. Example: https://github.com/openvinotoolkit/openvino/actions/runs/9709509944/job/26799144381#step:25:12135 So I simplify creation of types mapping and checking. **Ticket:** TBD --------- Signed-off-by: Kazantsev, Roman <[email protected]>
1 parent 64c7d92 commit 62a1f1b

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

.github/labeler.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@
141141
'category: TF FE':
142142
- 'src/frontends/tensorflow/**/*'
143143
- 'src/frontends/tensorflow_common/**/*'
144+
- 'src/bindings/python/src/openvino/frontend/tensorflow/**/*'
144145
- 'tests/layer_tests/tensorflow_tests/**/*'
145146
- 'tests/layer_tests/tensorflow2_keras_tests/**/*'
146147
- 'tests/layer_tests/jax_tests/**/*'
@@ -163,6 +164,7 @@
163164

164165
'category: JAX FE':
165166
- 'src/frontends/jax/**/*'
167+
- 'src/bindings/python/src/openvino/frontend/jax/**/*'
166168
- 'tests/layer_tests/jax_tests/**/*'
167169

168170
'category: tools':

src/bindings/python/src/openvino/frontend/jax/utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
from openvino.runtime import op, Type as OVType, Shape, OVAny
1212

1313
numpy_to_ov_type_map = {
14-
np.dtypes.Float32DType: OVType.f32,
15-
np.dtypes.BoolDType: OVType.boolean,
14+
np.float32: OVType.f32,
15+
bool: OVType.boolean,
1616
jax.dtypes.bfloat16: OVType.bf16, # TODO: check this
17-
np.dtypes.Float16DType: OVType.f16,
18-
np.dtypes.Float32DType: OVType.f32,
19-
np.dtypes.Float64DType: OVType.f64,
20-
np.dtypes.UInt8DType: OVType.u8,
21-
np.dtypes.Int8DType: OVType.i8,
22-
np.dtypes.Int16DType: OVType.i16,
23-
np.dtypes.Int32DType: OVType.i32,
24-
np.dtypes.Int64DType: OVType.i64,
17+
np.float16: OVType.f16,
18+
np.float32: OVType.f32,
19+
np.float64: OVType.f64,
20+
np.uint8: OVType.u8,
21+
np.int8: OVType.i8,
22+
np.int16: OVType.i16,
23+
np.int32: OVType.i32,
24+
np.int64: OVType.i64,
2525
}
2626

2727
jax_to_ov_type_map = {
@@ -74,7 +74,7 @@ def get_ov_type_for_value(value):
7474
if value.aval.dtype in jax_to_ov_type_map:
7575
return OVAny(jax_to_ov_type_map[value.aval.dtype])
7676
for k, v in numpy_to_ov_type_map.items():
77-
if isinstance(value.aval.dtype, k):
77+
if value.aval.dtype == k:
7878
return OVAny(v)
7979
for k, v in basic_to_ov_type_map.items():
8080
if isinstance(value.aval.dtype, k):
@@ -88,7 +88,7 @@ def get_ov_type_from_jax_type(dtype):
8888
if dtype in jax_to_ov_type_map:
8989
return OVAny(jax_to_ov_type_map[dtype])
9090
for k, v in numpy_to_ov_type_map.items():
91-
if isinstance(dtype, k):
91+
if dtype == k:
9292
return OVAny(v)
9393
for k, v in basic_to_ov_type_map.items():
9494
if isinstance(dtype, k):

0 commit comments

Comments
 (0)