-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Improve pruning module #2354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve pruning module #2354
Changes from 15 commits
5413af8
ca55a24
18435f2
b6eaed0
092828f
1a1c589
23a1480
1566848
a1e1472
97a2fde
997cef2
23b1a27
98e4bb7
4fbec82
5405dc0
6248e56
fc9e3e4
a3ada68
885275e
1a82e7d
15bf6e0
1eab8cc
aaf11fa
a266292
43771ad
fdde4ef
561c456
1d6b859
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,11 +14,13 @@ limitations under the License. */ | |
|
|
||
| #include "ParameterUpdaterHook.h" | ||
|
|
||
| #include <algorithm> | ||
| #include <atomic> | ||
| #include <fstream> | ||
| #include <mutex> | ||
| #include <thread> | ||
| #include <unordered_map> | ||
| #include <vector> | ||
|
|
||
| #include "paddle/math/Vector.h" | ||
| #include "paddle/parameter/Parameter.h" | ||
|
|
@@ -29,40 +31,21 @@ namespace paddle { | |
|
|
||
| /** | ||
| * The static pruning hook | ||
| * | ||
| * Static means user load a mask map before training started. This map will | ||
| * define which link/weight between neural is disabled. | ||
| * Static means the user specifies a sparsity_ratio before training starts, and the | ||
| * network will prune the parameters based on the sparsity_ratio. More details | ||
| * can be found at https://arxiv.org/pdf/1506.02626.pdf. | ||
| */ | ||
|
|
||
| class StaticPruningHook : public IParameterUpdaterHook { | ||
| public: | ||
| /** | ||
| * The Mask Map Header. | ||
| * The map file started with this header. | ||
| * | ||
| * In Version 0, reset file will be: | ||
| * contains header.size bit, each bit means such weight is enabled or not. | ||
| * if bit is 1, then such weight is enabled. | ||
| * at end, the file will round to byte, and the low bits of end byte will be | ||
| * filled by zero. | ||
| * | ||
| */ | ||
| struct StaticMaskHeader { | ||
| uint32_t version; | ||
| size_t size; | ||
| } __attribute__((__packed__)); | ||
|
|
||
| explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) { | ||
| bool ok = this->loadMaskFile(mask_filename); | ||
| if (!ok) { | ||
| LOG(WARNING) << "Fail to load mask file " << mask_filename | ||
| << " in current directory, searching in init_model_path"; | ||
| std::string combineMaskFilename = | ||
| path::join(FLAGS_init_model_path, mask_filename); | ||
| CHECK(this->loadMaskFile(combineMaskFilename)) | ||
| << "Cannot load " << mask_filename << " in ./" << mask_filename | ||
| << " and " << combineMaskFilename; | ||
| } | ||
| VLOG(3) << mask_filename << " mask size = " << this->mask_.size(); | ||
| explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig) | ||
| : initCount_(0) { | ||
| sparsityRatio_ = hookConfig.sparsity_ratio(); | ||
| } | ||
|
|
||
| static bool sortPairAscend(const std::pair<real, size_t>& pair1, | ||
| const std::pair<real, size_t>& pair2) { | ||
| return pair1.first > pair2.first; | ||
| } | ||
|
|
||
| void update(Parameter* para) { | ||
|
|
@@ -73,62 +56,53 @@ class StaticPruningHook : public IParameterUpdaterHook { | |
| } | ||
| } | ||
|
|
||
| void init(Parameter* para) { | ||
| size_t initCount = this->initCount_.fetch_add(1); | ||
| CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " | ||
| "in same ParamterUpdater"; | ||
| VLOG(3) << "Initialize Parameter " << para; | ||
| SetDevice device(para->getDeviceId()); | ||
| void generateMask(Parameter* para) { | ||
| VectorPtr vec = para->getBuf(PARAMETER_VALUE); | ||
| maskTemp_ = Vector::create(para->getSize(), false); | ||
|
||
| maskTemp_->zeroMem(); | ||
| real* dataPtr = maskTemp_->getData(); | ||
| size_t sparsityNum = para->getSize() * (1 - sparsityRatio_); | ||
|
||
|
|
||
| auto maskVec = Vector::create(this->mask_.size(), false); | ||
| { // Initialize maskVec with float mask vector | ||
| real* dataPtr = maskVec->getData(); | ||
| size_t i = 0; | ||
| for (bool m : mask_) { | ||
| dataPtr[i++] = m ? 1.0 : 0.0; | ||
| } | ||
| } | ||
| VectorPtr vecCpu = Vector::create(para->getSize(), false); | ||
| vecCpu->copyFrom(*vec); | ||
| std::vector<std::pair<real, size_t>> param; | ||
|
|
||
| for (size_t i = 0; i < para->getSize(); i++) | ||
| param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); | ||
|
|
||
| std::partial_sort(param.begin(), | ||
| param.begin() + sparsityNum, | ||
| param.end(), | ||
| sortPairAscend); | ||
| for (size_t i = 0; i < sparsityNum; i++) dataPtr[param[i].second] = 1.0; | ||
|
|
||
| // Currently just use a mask vector for hack. | ||
| // @TODO(yuyang18): Implemented the mask operation in vector. | ||
| if (para->useGpu()) { | ||
| maskVec_ = Vector::create(this->mask_.size(), para->useGpu()); | ||
| maskVec_->copyFrom(*maskVec); | ||
| maskVec_ = Vector::create(para->getSize(), para->useGpu()); | ||
| maskVec_->copyFrom(*maskTemp_); | ||
| } else { | ||
| maskVec_ = maskVec; | ||
| maskVec_ = maskTemp_; | ||
| } | ||
| } | ||
|
|
||
| void init(Parameter* para) { | ||
| generateMask(para); | ||
| size_t initCount = this->initCount_.fetch_add(1); | ||
| CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " | ||
| "in same ParamterUpdater"; | ||
| VLOG(3) << "Initialize Parameter " << para; | ||
| SetDevice device(para->getDeviceId()); | ||
|
|
||
| auto& vec = para->getBuf(PARAMETER_VALUE); | ||
| vec->dotMul(*maskVec_); | ||
| } | ||
|
|
||
| private: | ||
| bool loadMaskFile(const std::string& mask_filename) { | ||
| std::ifstream fin; | ||
| fin.open(mask_filename); | ||
| if (fin.is_open()) { | ||
| StaticMaskHeader header; | ||
| fin.read(reinterpret_cast<char*>(&header), sizeof(StaticMaskHeader)); | ||
| CHECK_EQ(header.version, 0UL); | ||
| mask_.resize(header.size); | ||
| uint8_t buf; | ||
| for (size_t i = 0; i < header.size; ++i, buf <<= 1) { | ||
| if (i % 8 == 0) { | ||
| fin.read(reinterpret_cast<char*>(&buf), sizeof(uint8_t)); | ||
| } | ||
| mask_[i] = buf & 0x80; | ||
| } | ||
| fin.close(); | ||
| return true; | ||
| } else { | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| SameThreadChecker updateThreadChecker_; | ||
| std::atomic<size_t> initCount_; | ||
| VectorPtr maskVec_; | ||
| std::vector<bool> mask_; | ||
| VectorPtr maskTemp_; | ||
|
||
| real sparsityRatio_; | ||
| }; | ||
|
|
||
| IParameterUpdaterHook::IParameterUpdaterHook() {} | ||
|
|
@@ -166,10 +140,10 @@ static IParameterUpdaterHook* createImpl( | |
| const ParameterUpdaterHookConfig& config) { | ||
| auto& type = config.type(); | ||
| if (type == "pruning") { | ||
| if (config.has_purning_mask_filename()) { | ||
| return new StaticPruningHook(config.purning_mask_filename()); | ||
| } | ||
| return new StaticPruningHook(config); | ||
| } | ||
|
|
||
| LOG(FATAL) << "Unknown Hook type: " << type; | ||
|
||
| return nullptr; | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,8 +25,9 @@ enum ParameterInitStrategy { | |
| } | ||
|
|
||
| message ParameterUpdaterHookConfig { | ||
| // hook type such as 'pruning' | ||
| required string type = 1; | ||
| optional string purning_mask_filename = 2; | ||
| optional double sparsity_ratio = 2 [default = 0.8]; | ||
|
||
| } | ||
|
|
||
| message ParameterConfig { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,8 @@ | |
|
|
||
| from paddle.trainer.config_parser import * | ||
| __all__ = [ | ||
| 'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute' | ||
| 'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute', | ||
| 'ExtraLayerAttribute' | ||
| ] | ||
|
|
||
|
|
||
|
|
@@ -55,6 +56,33 @@ def is_compatible_with(x, Type): | |
| return False | ||
|
|
||
|
|
||
| class HookAttribute(object): | ||
| """ | ||
| Hook Attribute object. The hook is an auxiliary operation that occurs | ||
| during network propagation. | ||
| NOTE: IT IS A HIGH LEVEL USER INTERFACE. | ||
|
||
|
|
||
| :param type: Hook type, eg: 'pruning' | ||
|
||
| :type type: string | ||
|
|
||
| :param sparsity_ratio: Must be specified if hook type is 'pruning' | ||
| :type sparsity_ratio: float or None | ||
|
|
||
| """ | ||
|
|
||
| def __init__(self, type, sparsity_ratio=None): | ||
| self.type = type | ||
| self.sparsity_ratio = sparsity_ratio | ||
| if self.sparsity_ratio is not None: | ||
| assert is_compatible_with( | ||
| self.sparsity_ratio, | ||
| float), 'sparisity_ratio must be float type' | ||
| assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' | ||
|
||
|
|
||
| def __call__(self): | ||
| return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) | ||
|
|
||
|
|
||
| class ParameterAttribute(object): | ||
| """ | ||
| Parameter Attributes object. To fine-tuning network training process, user | ||
|
|
@@ -109,7 +137,8 @@ def __init__(self, | |
| learning_rate=None, | ||
| momentum=None, | ||
| gradient_clipping_threshold=None, | ||
| sparse_update=False): | ||
| sparse_update=False, | ||
| update_hooks=None): | ||
| self.attr = {} | ||
|
|
||
| if is_static: | ||
|
|
@@ -162,6 +191,9 @@ def __init__(self, | |
| self.attr['gradient_clipping_threshold'] = \ | ||
| gradient_clipping_threshold | ||
|
|
||
| if update_hooks: | ||
| self.attr['update_hooks'] = update_hooks | ||
|
|
||
| def set_default_parameter_name(self, name): | ||
| """ | ||
| Set default parameter name. If parameter not set, then will use default | ||
|
|
@@ -237,5 +269,6 @@ def to_kwargs(attr): | |
| return attr.attr | ||
|
|
||
|
|
||
| HookAttr = HookAttribute | ||
| ParamAttr = ParameterAttribute | ||
| ExtraAttr = ExtraLayerAttribute | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: specific -> specify, start -> started
More deatils can see -> More details can be found