Skip to content

Commit 308bdc8

Browse files
committed
fix SymbolicInt
1 parent c193070 commit 308bdc8

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

paddle/fluid/pybind/sot/guards.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ bool ShapeMatchGuard::check(PyObject* value) {
141141
return false;
142142
}
143143
for (size_t i = 0; i < shape.size(); ++i) {
144-
if (shape[i] <= 0 || shape[i] != expected_[i]) {
144+
if (expected_[i] == -1 && shape[i] >= 1) continue;
145+
if (shape[i] != expected_[i]) {
145146
return false;
146147
}
147148
}

paddle/fluid/pybind/sot/guards.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,11 @@ class ShapeMatchGuard : public GuardBase {
159159
explicit ShapeMatchGuard(const std::vector<py::object>& shape) {
160160
expected_.reserve(shape.size());
161161
for (const auto& s : shape) {
162-
expected_.push_back(s.cast<int64_t>());
162+
if (py::isinstance<py::int_>(s)) {
163+
expected_.push_back(s.cast<int64_t>());
164+
} else {
165+
expected_.push_back(-1);
166+
}
163167
}
164168
}
165169

0 commit comments

Comments
 (0)