Skip to content

Commit 93e334a

Browse files
fix_PyObject_ToInt32 (PaddlePaddle#76419)
1 parent 0e6f7fb commit 93e334a

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

paddle/fluid/pybind/op_function_common.cc

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -98,36 +98,40 @@ bool PyObject_CheckLong(PyObject* obj) {
9898
}
9999

100100
int32_t PyObject_ToInt32(PyObject* obj) {
101-
int64_t res = 0;
101+
int32_t res = 0;
102102
if ((PyLong_Check(obj) && !PyBool_Check(obj)) || // NOLINT
103103
PyObject_CheckVarType(obj) || // NOLINT
104104
PyObject_CheckDataType(obj) || // NOLINT
105105
(PyObject_CheckTensor(obj) &&
106106
reinterpret_cast<TensorObject*>(obj)->tensor.numel() == 1)) {
107-
res = PyLong_AsLongLong(obj);
108-
} else {
109-
std::string type_name =
110-
std::string(reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name);
111-
if (type_name.find("numpy.int") != std::string::npos) {
112-
auto num_obj = PyNumber_Long(obj);
113-
res = PyLong_AsLongLong(num_obj);
114-
Py_DECREF(num_obj);
115-
} else {
116-
PADDLE_THROW(
117-
common::errors::InvalidType("Cannot convert %s to int32", type_name));
107+
res = static_cast<int32_t>(PyLong_AsLong(obj));
108+
if (res == -1 && PyErr_Occurred()) {
109+
PyErr_Clear();
110+
PADDLE_THROW(common::errors::OutOfRange(
111+
"Integer value exceeds int32 range [%d, %d]",
112+
std::numeric_limits<int32_t>::min(),
113+
std::numeric_limits<int32_t>::max()));
118114
}
115+
return res;
119116
}
120-
121-
if (res > std::numeric_limits<int32_t>::max() ||
122-
res < std::numeric_limits<int32_t>::min()) {
123-
PADDLE_THROW(common::errors::OutOfRange(
124-
"Integer value %ld exceeds int32 range [%d, %d]",
125-
res,
126-
std::numeric_limits<int32_t>::min(),
127-
std::numeric_limits<int32_t>::max()));
117+
std::string type_name =
118+
std::string(reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name);
119+
if (type_name.find("numpy.int") != std::string::npos) {
120+
auto num_obj = PyNumber_Long(obj);
121+
res = static_cast<int32_t>(PyLong_AsLong(num_obj));
122+
if (res == -1 && PyErr_Occurred()) {
123+
PyErr_Clear();
124+
PADDLE_THROW(common::errors::OutOfRange(
125+
"Integer value exceeds int32 range [%d, %d]",
126+
std::numeric_limits<int32_t>::min(),
127+
std::numeric_limits<int32_t>::max()));
128+
}
129+
Py_DECREF(num_obj);
130+
} else {
131+
PADDLE_THROW(
132+
common::errors::InvalidType("Cannot convert %s to long", type_name));
128133
}
129-
130-
return static_cast<int32_t>(res);
134+
return res;
131135
}
132136

133137
uint32_t PyObject_ToUInt32(PyObject* obj) {

0 commit comments

Comments
 (0)