Skip to content

Commit 1dc53a2

Browse files
committed
Use friend not to expose tensor's type/place
1 parent a89c7ff commit 1dc53a2

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

paddle/framework/tensor.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ limitations under the License. */
2424
#include "paddle/platform/place.h"
2525

2626
namespace paddle {
27+
namespace pybind {
28+
namespace details { // forward declare
29+
template <bool less, size_t i, typename... args>
30+
struct CastToPyBufferImpl;
31+
} // namespace details
32+
} // namespace pybind
2733
namespace framework {
2834

2935
class Tensor {
@@ -128,10 +134,6 @@ class Tensor {
128134

129135
DDim dims() const { return dims_; }
130136

131-
platform::Place place() const { return holder_->place(); }
132-
133-
std::type_index type() const { return holder_->type(); }
134-
135137
private:
136138
// Placeholder hides type T, so it doesn't appear as a template
137139
// parameter of Variable.
@@ -186,7 +188,9 @@ class Tensor {
186188
DDim dims_;
187189
size_t numel_; // cache of `product(dims_)`
188190
size_t offset_; // marks the begin of tensor data area.
189-
}; // namespace framework
191+
template <bool less, size_t i, typename... args>
192+
friend struct paddle::pybind::details::CastToPyBufferImpl;
193+
}; // namespace framework
190194

191195
} // namespace framework
192196
} // namespace paddle

paddle/pybind/pybind.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ limitations under the License. */
1515
#include <Python.h>
1616
#include <paddle/framework/op_registry.h>
1717
#include <paddle/framework/scope.h>
18-
#include <paddle/pybind/tensor.h>
18+
#include <paddle/pybind/tensor_bind.h>
1919
#include <pybind11/numpy.h>
2020
#include <pybind11/pybind11.h>
2121
#include <pybind11/stl.h>
@@ -32,8 +32,6 @@ PYBIND11_PLUGIN(core) {
3232

3333
py::class_<pd::Tensor>(m, "Tensor", py::buffer_protocol())
3434
.def_buffer([](pd::Tensor& self) -> py::buffer_info {
35-
PADDLE_ENFORCE(paddle::platform::is_cpu_place(self.place()),
36-
"Only CPU tensor can cast to numpy array");
3735
return paddle::pybind::CastToPyBuffer(self);
3836
})
3937
.def("get_dims",
Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ template <size_t I, typename... ARGS>
4040
struct CastToPyBufferImpl<true, I, ARGS...> {
4141
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
4242
py::buffer_info operator()(framework::Tensor &tensor) {
43-
if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) {
43+
PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
44+
"Only CPU tensor can cast to numpy array");
45+
46+
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
4447
auto dim_vec = framework::vectorize(tensor.dims());
4548
std::vector<size_t> dims_outside;
4649
std::vector<size_t> strides;
@@ -54,12 +57,13 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
5457
prod *= dims_outside[i - 1];
5558
}
5659

57-
return py::buffer_info(tensor.mutable_data<CUR_TYPE>(tensor.place()),
58-
sizeof(CUR_TYPE),
59-
py::format_descriptor<CUR_TYPE>::format(),
60-
(size_t)framework::arity(tensor.dims()),
61-
dims_outside,
62-
strides);
60+
return py::buffer_info(
61+
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
62+
sizeof(CUR_TYPE),
63+
py::format_descriptor<CUR_TYPE>::format(),
64+
(size_t)framework::arity(tensor.dims()),
65+
dims_outside,
66+
strides);
6367
} else {
6468
constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
6569
return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);

0 commit comments

Comments
 (0)