Skip to content

Commit eecfe90

Browse files
[OpenVINO backend] fix numpy conversions (#21498)
* [OpenVINO backend] fix numpy conversions * fix typo Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 77883ff commit eecfe90

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

keras/src/backend/openvino/core.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,21 @@ def __mod__(self, other):
492492
)
493493
return OpenVINOKerasTensor(ov_opset.mod(first, other).output(0))
494494

495+
def __array__(self, dtype=None):
496+
try:
497+
tensor = cast(self, dtype=dtype) if dtype is not None else self
498+
return convert_to_numpy(tensor)
499+
except Exception as e:
500+
raise RuntimeError(
501+
"An OpenVINOKerasTensor is symbolic: it's a placeholder "
502+
"for a shape and a dtype.\n"
503+
"It doesn't have any actual numerical value.\n"
504+
"You cannot convert it to a NumPy array."
505+
) from e
506+
507+
def numpy(self):
508+
return self.__array__()
509+
495510

496511
def ov_to_keras_type(ov_type):
497512
for _keras_type, _ov_type in OPENVINO_DTYPES.items():
@@ -672,8 +687,10 @@ def convert_to_numpy(x):
672687
ov_model = Model(results=[ov_result], parameters=[])
673688
ov_compiled_model = compile_model(ov_model, get_device())
674689
result = ov_compiled_model({})[0]
675-
except:
676-
raise "`convert_to_numpy` cannot convert to numpy"
690+
except Exception as inner_exception:
691+
raise RuntimeError(
692+
"`convert_to_numpy` failed to convert the tensor."
693+
) from inner_exception
677694
return result
678695

679696

@@ -690,6 +707,7 @@ def shape(x):
690707

691708

692709
def cast(x, dtype):
710+
dtype = standardize_dtype(dtype)
693711
ov_type = OPENVINO_DTYPES[dtype]
694712
x = get_ov_output(x)
695713
return OpenVINOKerasTensor(ov_opset.convert(x, ov_type).output(0))

0 commit comments

Comments
 (0)