Skip to content

Commit 1480720

Browse files
authored
Merge pull request #2891 from Canpio/dev_enable_tensor_test
Enable tensor test
2 parents df707d0 + 1cd14f6 commit 1480720

File tree

2 files changed

+46
-27
lines changed

2 files changed

+46
-27
lines changed

paddle/framework/tensor.h

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ class Tensor {
2929
public:
3030
Tensor() : numel_(0), offset_(0) {}
3131

32-
Tensor& operator=(const Tensor& src) = delete;
33-
3432
template <typename T>
3533
const T* data() const {
3634
CheckDims<T>();
@@ -39,21 +37,37 @@ class Tensor {
3937
}
4038

4139
template <typename T>
42-
T* mutable_data(DDim dims, paddle::platform::Place place) {
40+
T* mutable_data(DDim dims, platform::Place place) {
4341
set_dims(dims);
4442
return mutable_data<T>(place);
4543
}
4644

4745
template <typename T>
48-
T* mutable_data(paddle::platform::Place place) {
46+
T* mutable_data(platform::Place place) {
4947
PADDLE_ENFORCE(numel_ > 0,
5048
"Tensor::numel_ must be larger than zero to call "
5149
"Tensor::mutable_data. Call Tensor::set_dim first.");
5250
if (holder_ == nullptr ||
5351
!(holder_->place() ==
5452
place) /* some versions of boost::variant don't have operator!= */
5553
|| holder_->size() < numel_ * sizeof(T) + offset_) {
56-
holder_.reset(new PlaceholderImpl<T>(place, numel_ * sizeof(T)));
54+
#ifdef __CUDACC__
55+
switch (place.which()) {
56+
case 0:
57+
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
58+
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
59+
break;
60+
61+
case 1:
62+
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
63+
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
64+
break;
65+
}
66+
#else
67+
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
68+
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
69+
#endif
70+
5771
offset_ = 0;
5872
}
5973
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
@@ -69,7 +83,7 @@ class Tensor {
6983
}
7084

7185
template <typename T>
72-
void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) {
86+
void CopyFrom(const Tensor& src, platform::Place dst_place) {
7387
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
7488
platform::is_cpu_place(dst_place),
7589
"Tensor::CopyFrom only support CPU now.");
@@ -119,38 +133,37 @@ class Tensor {
119133
struct Placeholder {
120134
virtual ~Placeholder() {}
121135
virtual void* ptr() const = 0;
122-
virtual paddle::platform::Place place() const = 0;
136+
virtual platform::Place place() const = 0;
123137
virtual size_t size() const = 0;
124138
};
125139

126-
template <typename T>
140+
template <typename T, typename PlaceType>
127141
struct PlaceholderImpl : public Placeholder {
128142
private:
143+
template <typename PType>
129144
class Deleter {
130145
public:
131-
Deleter(platform::Place place) : place_(place) {}
132-
void operator()(T* ptr) {
133-
paddle::memory::Free(place_, static_cast<void*>(ptr));
134-
}
146+
Deleter(PType place) : place_(place) {}
147+
void operator()(T* ptr) { memory::Free(place_, static_cast<void*>(ptr)); }
135148

136149
private:
137-
paddle::platform::Place place_;
150+
PType place_;
138151
};
139152

140153
public:
141-
PlaceholderImpl(paddle::platform::Place place, size_t size)
142-
: ptr_(static_cast<T*>(paddle::memory::Alloc(place, size)),
143-
Deleter(place)),
154+
PlaceholderImpl(PlaceType place, size_t size)
155+
: ptr_(static_cast<T*>(memory::Alloc(place, size)),
156+
Deleter<PlaceType>(place)),
144157
place_(place),
145158
size_(size) {}
146159

147160
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
148161
virtual size_t size() const { return size_; }
149-
virtual paddle::platform::Place place() const { return place_; }
162+
virtual platform::Place place() const { return place_; }
150163

151-
std::unique_ptr<T, Deleter> ptr_;
152-
paddle::platform::Place place_; // record the place of ptr_.
153-
size_t size_; // size of the memory block.
164+
std::unique_ptr<T, Deleter<PlaceType>> ptr_;
165+
platform::Place place_; // record the place of ptr_.
166+
size_t size_; // size of the memory block.
154167
};
155168

156169
template <typename T>
@@ -166,7 +179,7 @@ class Tensor {
166179
DDim dims_;
167180
size_t numel_; // cache of `product(dims_)`
168181
size_t offset_; // marks the begin of tensor data area.
169-
};
182+
}; // namespace framework
170183

171184
} // namespace framework
172185
} // namespace paddle

paddle/framework/tensor_test.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ TEST(Tensor, DataAssert) {
4747

4848
/* following tests are not available at present
4949
because Memory::Alloc() and Memory::Free() have not been ready.
50-
50+
*/
5151
TEST(Tensor, MutableData) {
5252
using namespace paddle::framework;
5353
using namespace paddle::platform;
@@ -72,7 +72,7 @@ TEST(Tensor, MutableData) {
7272
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
7373
EXPECT_EQ(p1, p2);
7474
}
75-
75+
#ifdef __CUDACC__
7676
{
7777
Tensor src_tensor;
7878
float* p1 = nullptr;
@@ -94,6 +94,7 @@ TEST(Tensor, MutableData) {
9494
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
9595
EXPECT_EQ(p1, p2);
9696
}
97+
#endif
9798
}
9899

99100
TEST(Tensor, ShareDataFrom) {
@@ -108,9 +109,11 @@ TEST(Tensor, ShareDataFrom) {
108109
dst_tensor.ShareDataFrom<float>(src_tensor);
109110
} catch (EnforceNotMet err) {
110111
caught = true;
111-
std::string msg = "Tenosr holds no memory. Call Tensor::mutable_data
112-
first."; const char* what = err.what(); for (size_t i = 0; i < msg.length();
113-
++i) { ASSERT_EQ(what[i], msg[i]);
112+
std::string msg =
113+
"Tenosr holds no memory. Call Tensor::mutable_data first.";
114+
const char* what = err.what();
115+
for (size_t i = 0; i < msg.length(); ++i) {
116+
ASSERT_EQ(what[i], msg[i]);
114117
}
115118
}
116119
ASSERT_TRUE(caught);
@@ -120,13 +123,15 @@ first."; const char* what = err.what(); for (size_t i = 0; i < msg.length();
120123
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
121124
}
122125

126+
#ifdef __CUDACC__
123127
{
124128
Tensor src_tensor;
125129
Tensor dst_tensor;
126130
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
127131
dst_tensor.ShareDataFrom<int>(src_tensor);
128132
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
129133
}
134+
#endif
130135
}
131136

132137
TEST(Tensor, Slice) {
@@ -155,6 +160,7 @@ TEST(Tensor, Slice) {
155160
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
156161
}
157162

163+
#ifdef __CUDACC__
158164
{
159165
Tensor src_tensor;
160166
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
@@ -176,6 +182,7 @@ TEST(Tensor, Slice) {
176182
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
177183
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
178184
}
185+
#endif
179186
}
180187

181188
TEST(Tensor, CopyFrom) {
@@ -203,4 +210,3 @@ TEST(Tensor, CopyFrom) {
203210
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
204211
}
205212
}
206-
*/

0 commit comments

Comments
 (0)