Merged
Changes from 5 commits
2 changes: 1 addition & 1 deletion benchmark/paddle/image/provider.py
@@ -22,5 +22,5 @@ def initHook(settings, height, width, color, num_class, **kwargs):
def process(settings, file_list):
    for i in xrange(1024):
        img = np.random.rand(1, settings.data_size).reshape(-1, 1).flatten()
-       lab = random.randint(0, settings.num_class)
+       lab = random.randint(0, settings.num_class - 1)
        yield img.astype('float32'), int(lab)
50 changes: 50 additions & 0 deletions benchmark/paddle/image/run.mkldnn.sh
@@ -0,0 +1,50 @@
set -e
Contributor: Rename this file from run.mkldnn.sh to run_mkldnn.sh.

Contributor Author: OK.
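
For reference, the requested rename is a one-liner (illustrative):

git mv benchmark/paddle/image/run.mkldnn.sh benchmark/paddle/image/run_mkldnn.sh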


unset OMP_NUM_THREADS MKL_NUM_THREADS
export OMP_DYNAMIC="FALSE"
export KMP_AFFINITY="granularity=fine,compact,0,0"

function train() {
  topology=$1
  bs=$2
  thread=1
  if [ $3 ]; then
    thread=$3
  fi
  if [ $thread -eq 1 ]; then
    use_mkldnn=1
    log="logs/${topology}-mkldnn-${bs}.log"
  else
    use_mkldnn=0
    log="logs/${topology}-${thread}mklml-${bs}.log"
  fi
Contributor: Lines 14-20: when thread=1, mkldnn is used, and otherwise not? But looking at lines 41-50, they all use mkldnn. It would be clearer to pass trainer_count and use_mkldnn in as parameters at lines 41-50.

Contributor Author: When MKLDNN is used, only one trainer is needed, since MKLDNN does its own multi-threaded computation; that is why the script looks the way it does. I will rework it to expose all of the options, which will also make it easier for everyone to benchmark later.
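
A minimal sketch of the interface being suggested, with use_mkldnn and trainer_count passed in explicitly, might look like the following — an assumption about the eventual refactor, not the merged script:

# Hypothetical refactor: the caller chooses use_mkldnn and trainer_count.
function train() {
  topology=$1
  bs=$2
  use_mkldnn=$3
  thread=$4
  if [ ${use_mkldnn} -eq 1 ]; then
    log="logs/${topology}-mkldnn-${bs}.log"
  else
    log="logs/${topology}-${thread}mklml-${bs}.log"
  fi
  paddle train --job=time \
    --config=${topology}.py \
    --use_mkldnn=${use_mkldnn} \
    --use_gpu=False \
    --trainer_count=${thread} \
    --log_period=10 \
    --test_period=100 \
    --config_args="batch_size=${bs}" \
    2>&1 | tee ${log}
}

# train vgg 64 1 1     # MKLDNN: one trainer, threading handled inside MKLDNN
# train vgg 64 0 16    # MKLML: 16 trainers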

  args="batch_size=${bs}"
  config="${topology}.py"
  paddle train --job=time \
    --config=$config \
    --use_mkldnn=$use_mkldnn \
    --use_gpu=False \
    --trainer_count=$thread \
    --log_period=10 \
    --test_period=100 \
    --config_args=$args \
    2>&1 | tee ${log}
}

if [ ! -f "train.list" ]; then
  echo " " > train.list
fi
if [ ! -d "logs" ]; then
  mkdir logs
fi

#========= mkldnn =========#
# vgg
train vgg 64
train vgg 128
train vgg 256

#========== mklml ===========#
train vgg 64 16
train vgg 128 16
train vgg 256 16
Contributor: I don't understand the comments at lines 41 and 47: one uses mkldnn and the other doesn't? Also, about the last layer_num argument: since 16 is passed here, the earlier calls should pass 19 as well.

Contributor Author: What is being compared here is the difference between using MKLDNN and not using it, with both WITH_MKLDNN and WITH_MKLML enabled.

The last argument is not the number of layers; it is the trainer_count. MKLDNN uses one trainer by default, so it is not written out.
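
To spell out the positional arguments as the script stands: the optional third argument is trainer_count, while layer_num is a config option read by vgg.py and would go through --config_args. The annotated calls below are illustrative, not part of the script:

train vgg 64       # topology=vgg, batch_size=64, trainer_count defaults to 1 (MKLDNN)
train vgg 64 16    # topology=vgg, batch_size=64, trainer_count=16 (MKLML)
# e.g. paddle train --config=vgg.py --config_args=batch_size=64,layer_num=16 ...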

103 changes: 103 additions & 0 deletions benchmark/paddle/image/vgg.py
@@ -0,0 +1,103 @@
#!/usr/bin/env python
from paddle.trainer_config_helpers import *

height = 224
width = 224
num_class = 1000
batch_size = get_config_arg('batch_size', int, 64)
layer_num = get_config_arg('layer_num', int, 19)

args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
define_py_data_sources2(
    "train.list", None, module="provider", obj="process", args=args)

settings(
    batch_size=batch_size,
    learning_rate=0.01 / batch_size,
    learning_method=MomentumOptimizer(0.9),
    regularization=L2Regularization(0.0005 * batch_size))

img = data_layer(name='image', size=height * width * 3)


def vgg_network(vgg_num=3):
    tmp = img_conv_group(
        input=img,
        num_channels=3,
        conv_padding=1,
        conv_num_filter=[64, 64],
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_size=2,
        pool_stride=2,
        pool_type=MaxPooling())

    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=[128, 128],
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)

    channels = []
    for i in range(vgg_num):
        channels.append(256)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)

    channels = []
    for i in range(vgg_num):
        channels.append(512)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)

    tmp = fc_layer(
        input=tmp,
        size=4096,
        act=ReluActivation(),
        layer_attr=ExtraAttr(drop_rate=0.5))

    tmp = fc_layer(
        input=tmp,
        size=4096,
        act=ReluActivation(),
        layer_attr=ExtraAttr(drop_rate=0.5))

    return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())


if layer_num == 16:
    vgg = vgg_network(3)
elif layer_num == 19:
    vgg = vgg_network(4)
else:
    print("Wrong layer number.")

lab = data_layer('label', num_class)
loss = cross_entropy(input=vgg, label=lab)
outputs(loss)
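
Since layer_num is read via get_config_arg, the network depth is selectable at launch time. A hedged example invocation, reusing the flags from run.mkldnn.sh:

# Illustrative: VGG-16 is vgg_num=3 (2+2+3+3+3 conv layers + 3 fc);
# the default VGG-19 is vgg_num=4 (2+2+4+4+4 conv layers + 3 fc).
paddle train --job=time --config=vgg.py --use_gpu=False \
  --config_args=batch_size=128,layer_num=16 \
  2>&1 | tee logs/vgg16.log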
4 changes: 4 additions & 0 deletions cmake/util.cmake
@@ -97,6 +97,10 @@ function(link_paddle_exe TARGET_NAME)
    target_link_libraries(${TARGET_NAME} log)
  endif(ANDROID)

  if(WITH_MKLDNN AND WITH_MKLML AND MKLDNN_IOMP_DIR)
    target_link_libraries(${TARGET_NAME} "-L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
  endif()

  add_dependencies(${TARGET_NAME} ${external_project_dependencies})
endfunction()

3 changes: 2 additions & 1 deletion paddle/gserver/activations/MKLDNNActivation.h
@@ -100,6 +100,7 @@ class MKLDNNEltwiseActivation : public MKLDNNActivation {
if (cnt_ == act.value->getElementCnt()) {
return;
}
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
cnt_ = act.value->getElementCnt();
stream_.reset(new MKLDNNStream());
auto eng = CPUEngine::Instance().getEngine();
@@ -110,7 +111,6 @@ class MKLDNNEltwiseActivation : public MKLDNNActivation {
float alpha = getAlpha();
float beta = getBeta();

/// forward
pipelineFwd_.clear();
val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
if (val_ == nullptr) {
@@ -152,6 +152,7 @@ class MKLDNNEltwiseActivation : public MKLDNNActivation {
if (!needResetBwd_) {
return;
}
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
needResetBwd_ = false;
mkldnn::algorithm algo = getAlgo(this->getName());
float alpha = getBwdAlpha();
31 changes: 20 additions & 11 deletions paddle/gserver/layers/MKLDNNConvLayer.cpp
@@ -64,7 +64,7 @@ bool MKLDNNConvLayer::init(const LayerMap& layerMap,

// create biases
if (biasParameter_.get() != NULL) {
-    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
+    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_, 0));
}
return true;
}
@@ -251,22 +251,31 @@ void MKLDNNConvLayer::resetInValue(
// create buffer and reorder if input value do not match
cpuInVal_ = nullptr;
cvtInVal_ = nullptr;
if (inputIsOnlyMKLDNN()) {
MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
CHECK(dnnIn) << "Input should be MKLDNNMatrix";
if (dnnIn->getPrimitiveDesc() != in->getPrimitiveDesc()) {
CHECK_EQ(dnnIn->getFormat(), format::nc);

MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
CHECK_EQ(inputIsOnlyMKLDNN(), dnnIn != nullptr);
if (dnnIn != nullptr && dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) {
in = dnnIn;
return;
}
if (dnnIn) {
if (dnnIn->getFormat() == format::nc) {
CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format";
// create a new one with nchw format and same data
memory::dims inDims = memory::dims{bs_, ic_, 1, 1};
dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_);
CHECK(dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc());
}
in = dnnIn;
if (dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) {
in = dnnIn;
return;
}
cpuInVal_ = dnnIn;
in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc());
cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in);
CHECK(cvtInVal_) << "should not be empty";
} else {
const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE);
memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_};
cpuInVal_ = MKLDNNMatrix::create(cpuIn, inDims, format::nchw, engine_);
cpuInVal_ = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_);
if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) {
// create new mkldnn matrix
in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc());
@@ -535,7 +544,7 @@ void MKLDNNConvLayer::resetWgtValBwdData(
} else {
wgtValBwdData_ = wgtVal_;
}
-  VLOG(MKLDNN_FMTS) << "weight value format for backward data"
+  VLOG(MKLDNN_FMTS) << "weight value format for backward data: "
<< wgtValBwdData_->getFormat();
}

11 changes: 9 additions & 2 deletions paddle/gserver/layers/MKLDNNFcLayer.cpp
@@ -49,7 +49,7 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,

// create biases
if (biasParameter_.get() != NULL) {
-    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
+    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_, 0));
}
return true;
}
@@ -161,9 +161,16 @@ void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) {

void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias) {
  format wgtFmt = format::oihw;
  if (inVal_->getFormat() == format::nChw8c) {
    wgtFmt = format::oIhw8i;
  } else if (inVal_->getFormat() == format::nChw16c) {
    wgtFmt = format::oIhw16i;
  }
  wgt = MKLDNNMatrix::create(
-      weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_);
+      weight_->getW(), {oc_, ic_, ih_, iw_}, wgtFmt, engine_);
  wgt->downSpatial();
  VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat();

bias = (biases_ && biases_->getW())
? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_)
2 changes: 2 additions & 0 deletions paddle/gserver/layers/MKLDNNLayer.h
@@ -115,6 +115,7 @@ class MKLDNNLayer : public Layer {
copySeqInfoToOutputs();
size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt();
if (inputElemenCnt_ != elemenCnt) {
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
// reset when input total sizes changed, not only the batchsize
inputElemenCnt_ = elemenCnt;
reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_);
@@ -142,6 +143,7 @@

void backward(const UpdateCallback& callback) override {
if (needResetBwd_) {
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_);
needResetBwd_ = false;
}
13 changes: 13 additions & 0 deletions paddle/trainer/tests/CMakeLists.txt
@@ -37,6 +37,19 @@ add_test(NAME test_CompareTwoNets
--config_file_a=trainer/tests/sample_trainer_config_qb_rnn.conf --config_file_b=trainer/tests/sample_trainer_config_rnn.conf
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)

################ test_CompareMKLDNNandCPU ######################
if(WITH_MKLDNN)
  add_unittest_without_exec(test_CompareMKLDNNandCPU
      test_CompareTwoNets.cpp)
  add_test(NAME test_CompareMKLDNNandCPU
      COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python/
          ${CMAKE_CURRENT_BINARY_DIR}/test_CompareMKLDNNandCPU
          --config_file_a=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_a=True
          --config_file_b=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_b=False
          --use_gpu=False
Contributor: The use_gpu on line 49 can be dropped, right?

Contributor Author: It is there because, for this comparison, we do not want the GPU to be used.
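
Once PaddlePaddle is built with -DWITH_MKLDNN=ON, this comparison can be run through CTest; the build-directory layout here is an assumption:

# From the CMake build directory (illustrative):
ctest -R test_CompareMKLDNNandCPU --output-on-failure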

      WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)
endif()

############### test_CompareTwoOpts ###################
add_unittest_without_exec(test_CompareTwoOpts
test_CompareTwoOpts.cpp)
63 changes: 63 additions & 0 deletions paddle/trainer/tests/sample_trainer_config_simple_net.conf
@@ -0,0 +1,63 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *

################################### Data Configuration ###################################
TrainData(ProtoData(files = "trainer/tests/mnist.list"))
################################### Algorithm Configuration ###################################
settings(batch_size = 1000,
         learning_method = MomentumOptimizer(momentum=0.5, sparse=False))
################################### Network Configuration ###################################
data = data_layer(name ="input", size=784)

tmp = img_conv_layer(input=data,
                     num_channels=1,
                     filter_size=3,
                     num_filters=32,
                     padding=1,
                     shared_biases=True,
                     act=ReluActivation())

tmp = img_pool_layer(input=tmp,
                     pool_size=3,
                     stride=2,
                     padding=1,
                     pool_type=AvgPooling())

tmp = img_conv_layer(input=tmp,
                     filter_size=3,
                     num_filters=64,
                     padding=1,
                     shared_biases=True,
                     act=ReluActivation())

tmp = img_pool_layer(input=tmp,
                     pool_size=3,
                     stride=2,
                     padding=1,
                     pool_type=MaxPooling())

tmp = fc_layer(input=tmp, size=64,
               bias_attr=True,
               act=ReluActivation())

output = fc_layer(input=tmp, size=10,
                  bias_attr=True,
                  act=SoftmaxActivation())

lbl = data_layer(name ="label", size=10)

cost = classification_cost(input=output, label=lbl)
outputs(cost)