@@ -29,8 +29,6 @@ class Tensor {
2929 public:
3030 Tensor () : numel_(0 ), offset_(0 ) {}
3131
32- Tensor& operator =(const Tensor& src) = delete ;
33-
3432 template <typename T>
3533 const T* data () const {
3634 CheckDims<T>();
@@ -39,21 +37,37 @@ class Tensor {
3937 }
4038
4139 template <typename T>
42- T* mutable_data (DDim dims, paddle:: platform::Place place) {
40+ T* mutable_data (DDim dims, platform::Place place) {
4341 set_dims (dims);
4442 return mutable_data<T>(place);
4543 }
4644
4745 template <typename T>
48- T* mutable_data (paddle:: platform::Place place) {
46+ T* mutable_data (platform::Place place) {
4947 PADDLE_ENFORCE (numel_ > 0 ,
5048 " Tensor::numel_ must be larger than zero to call "
5149 " Tensor::mutable_data. Call Tensor::set_dim first." );
5250 if (holder_ == nullptr ||
5351 !(holder_->place () ==
5452 place) /* some versions of boost::variant don't have operator!= */
5553 || holder_->size () < numel_ * sizeof (T) + offset_) {
56- holder_.reset (new PlaceholderImpl<T>(place, numel_ * sizeof (T)));
54+ #ifdef __CUDACC__
55+ switch (place.which ()) {
56+ case 0 :
57+ holder_.reset (new PlaceholderImpl<T, platform::GPUPlace>(
58+ boost::get<platform::GPUPlace>(place), numel_ * sizeof (T)));
59+ break ;
60+
61+ case 1 :
62+ holder_.reset (new PlaceholderImpl<T, platform::CPUPlace>(
63+ boost::get<platform::CPUPlace>(place), numel_ * sizeof (T)));
64+ break ;
65+ }
66+ #else
67+ holder_.reset (new PlaceholderImpl<T, platform::CPUPlace>(
68+ boost::get<platform::CPUPlace>(place), numel_ * sizeof (T)));
69+ #endif
70+
5771 offset_ = 0 ;
5872 }
5973 return reinterpret_cast <T*>(reinterpret_cast <uintptr_t >(holder_->ptr ()) +
@@ -69,7 +83,7 @@ class Tensor {
6983 }
7084
7185 template <typename T>
72- void CopyFrom (const Tensor& src, paddle:: platform::Place dst_place) {
86+ void CopyFrom (const Tensor& src, platform::Place dst_place) {
7387 PADDLE_ENFORCE (platform::is_cpu_place (src.holder_ ->place ()) &&
7488 platform::is_cpu_place (dst_place),
7589 " Tensor::CopyFrom only support CPU now." );
@@ -119,38 +133,37 @@ class Tensor {
119133 struct Placeholder {
120134 virtual ~Placeholder () {}
121135 virtual void * ptr () const = 0;
122- virtual paddle:: platform::Place place () const = 0;
136+ virtual platform::Place place () const = 0;
123137 virtual size_t size () const = 0;
124138 };
125139
126- template <typename T>
140+ template <typename T, typename PlaceType >
127141 struct PlaceholderImpl : public Placeholder {
128142 private:
143+ template <typename PType>
129144 class Deleter {
130145 public:
131- Deleter (platform::Place place) : place_(place) {}
132- void operator ()(T* ptr) {
133- paddle::memory::Free (place_, static_cast <void *>(ptr));
134- }
146+ Deleter (PType place) : place_(place) {}
147+ void operator ()(T* ptr) { memory::Free (place_, static_cast <void *>(ptr)); }
135148
136149 private:
137- paddle::platform::Place place_;
150+ PType place_;
138151 };
139152
140153 public:
141- PlaceholderImpl (paddle::platform::Place place, size_t size)
142- : ptr_(static_cast <T*>(paddle:: memory::Alloc(place, size)),
143- Deleter (place)),
154+ PlaceholderImpl (PlaceType place, size_t size)
155+ : ptr_(static_cast <T*>(memory::Alloc(place, size)),
156+ Deleter<PlaceType> (place)),
144157 place_ (place),
145158 size_(size) {}
146159
147160 virtual void * ptr () const { return static_cast <void *>(ptr_.get ()); }
148161 virtual size_t size () const { return size_; }
149- virtual paddle:: platform::Place place () const { return place_; }
162+ virtual platform::Place place () const { return place_; }
150163
151- std::unique_ptr<T, Deleter> ptr_;
152- paddle:: platform::Place place_; // record the place of ptr_.
153- size_t size_; // size of the memory block.
164+ std::unique_ptr<T, Deleter<PlaceType> > ptr_;
165+ platform::Place place_; // record the place of ptr_.
166+ size_t size_; // size of the memory block.
154167 };
155168
156169 template <typename T>
@@ -166,7 +179,7 @@ class Tensor {
166179 DDim dims_;
167180 size_t numel_; // cache of `product(dims_)`
168181 size_t offset_; // marks the begin of tensor data area.
169- };
182+ }; // namespace framework
170183
171184} // namespace framework
172185} // namespace paddle
0 commit comments