move net_design to framework #2553
Closed
22 commits
5886959
move net_design to framework
Superjomn a00900a
add nullptr to Network.Run
Superjomn 4fb581b
add usage
Superjomn bb33f7a
remove scope* in construction
Superjomn a7dbfe0
remove AddOp from user's interface
Superjomn 8cf2d60
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
Superjomn 2933da0
net new design
Superjomn 7d5b1b2
add more details to net design
Superjomn b6450da
disable CreateOp from user
Superjomn a46506d
move w1 to scope
Superjomn 3c7bf55
add rnn compatibility
Superjomn ebe143d
fix english grammer
Superjomn 4e5a359
rrn -> rnn
Superjomn 9be4c74
change CreateNet result to unique_ptr
Superjomn 7102b80
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
Superjomn 2953215
make Run() stateless
Superjomn 58fdcd0
complete a unique_ptr's >
Superjomn 417b279
pr details fix
Superjomn b68c8bf
add InferShape
Superjomn 62823b1
rename "ScratchNet" -> "PlainNet"
Superjomn deff7cc
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
Superjomn 740f3a6
detail config
Superjomn
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| # Network Design | ||
|
|
||
| `Network` is the container and controller of a set of operators. | ||
| Users can build a real network from a `NetDef` protobuf message | ||
| and use `Network.Run()` to run all the operators in the network. | ||
|
|
||
| The `Network` will | ||
|
|
||
| - manage all the operators contained in the network. | ||
| - not own any `Variable`. | ||
|
|
||
| # API | ||
|
|
||
| ## NetworkBase | ||
| To make `Network` extensible, a base class is defined as follows: | ||
|
|
||
| ```c++ | ||
| // operator's index stored in a network. | ||
| typedef int OpIndex; | ||
|
|
||
| // The minimum interface a network should implement. | ||
| class NetworkBase { | ||
| public: | ||
| // Implementations take `def`, a proto message that describes the structure of a network. | ||
| NetworkBase(); | ||
|
|
||
| // Infer the shapes of variables required by operators in the network. The | ||
| // `scope` will be mutated according to the inferred shapes. | ||
| virtual bool InferShape(Scope *scope) = 0; | ||
|
|
||
| // Run all the operators and return true on success. All the | ||
| // variables are located in `scope`. `begin` and `end` specify the range of | ||
| // `ops_` to run. If no non-negative indexes are provided, all operators in | ||
| // `ops_` will run. | ||
| virtual bool Run(Scope *scope, OpIndex begin = -1, | ||
| OpIndex end = -1) const = 0; | ||
| }; | ||
| ``` | ||
|
|
||
| All network implementations should build networks from a protobuf message that | ||
| describes the structure of a real network. Every implementation should provide a `Run` method, | ||
| which offers a uniform way to run the forward or backward computation of a network. | ||
|
|
||
| A factory function can be defined as follows: | ||
|
|
||
| ```c++ | ||
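| std::unique_ptr<NetworkBase> CreateNet(const NetDef& def) { | ||
| // Wrap each raw pointer explicitly: unique_ptr has no implicit conversion from T*. | ||
| switch (def.model_type()) { | ||
| case NN: | ||
| return std::unique_ptr<NetworkBase>(new Network(def)); | ||
| case Recursive: | ||
| return std::unique_ptr<NetworkBase>(new RecursiveNet(def)); | ||
| case Recurrent: | ||
| return std::unique_ptr<NetworkBase>(new RecurrentNet(def)); | ||
| } | ||
| return nullptr; | ||
| } | ||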
| ``` | ||
|
|
||
| `Network` is designed as a container of operators. To make it more extensible, | ||
| we decouple it from the related variable resources. | ||
|
|
||
| `Run(Scope* scope)` takes the scope as an argument so that it can run in different scopes. | ||
|
|
||
| Finally, `NetworkBase` can be used as follows: | ||
|
|
||
| ```c++ | ||
| Scope default_scope; | ||
| auto net = CreateNet(def); | ||
|
|
||
| if (net) { | ||
| net->Run(&default_scope); | ||
| } | ||
| ``` | ||
|
|
||
| ## A Simple Network Implementation | ||
|
|
||
| A very basic implementation is shown below. It simply runs every operator in sequence. | ||
|
|
||
| ```c++ | ||
| class PlainNet final : public NetworkBase { | ||
| public: | ||
| // Create a network described by `def`. NetDef is the definition of a network. | ||
| PlainNet(const NetDef &def); | ||
|
|
||
| virtual bool InferShape(Scope *scope) override; | ||
|
|
||
| // Run all the operators with the `scope`. If no scope is provided, the default | ||
| // scope will be used instead. | ||
| virtual bool Run(Scope *scope = nullptr, OpIndex begin = -1, | ||
| OpIndex end = -1) const override; | ||
|
|
||
| const std::vector<Operator> &GetOps() const; | ||
|
|
||
| std::vector<Operator> *MutableOps(); | ||
|
|
||
| protected: | ||
| // Create operators according to `def`. | ||
| bool CreateNet(const NetDef &def); | ||
|
|
||
| // Add an operator which is identified as `type` and has attributes described | ||
| // in `attrs`. The `inputs` are the keys of read-only input variables, and | ||
| // `outputs` are the keys of mutable output variables. An `OpIndex` is | ||
| // returned, which indicates the offset of the new operator in `ops_`. | ||
| OpIndex AddOp(const std::string &type, const std::vector<std::string> &inputs, | ||
| const std::vector<std::string> &outputs, | ||
| const OprAttr &attrs = OprAttr()); | ||
|
|
||
| private: | ||
| // The operators owned by the network. | ||
| std::vector<Operator> ops_; | ||
| }; | ||
| ``` | ||
|
|
||
| `PlainNet` creates and owns its operators, so it defines a private member `ops_`. | ||
| `CreateNet` builds the operators from the `NetDef`, adding each one through `AddOp`. | ||
|
|
||
|
|
||
| ## Usage | ||
| `PlainNet` can be used to define and run a network as follows: | ||
|
|
||
| ```c++ | ||
| // create an empty scope located on CPU device. | ||
| Scope scope(CPUPlace()); | ||
|
|
||
| // create and init variables described in `net_desc`. | ||
| scope.CreateVariables(net_desc); | ||
| scope.InitVariables(net_desc); | ||
|
|
||
| // create a network according to `net_desc` | ||
| auto net = CreateNet(net_desc); | ||
|
|
||
| // run the network providing the `scope`. | ||
| net->Run(&scope); | ||
| ``` | ||
|
|
||
| ## Compatibility with RNN | ||
|
|
||
| Because `PlainNet.Run` is decoupled from `Scope`, `PlainNet` is compatible with the future RNN design. | ||
| For example, we can implement a simple recurrent neural network as follows: | ||
|
|
||
| ```c++ | ||
| // Copy some `vars` from `source` to `target`. | ||
| void Copy(const Scope &source, Scope &target, | ||
| const std::vector<std::string> &vars); | ||
|
|
||
| Scope default_scope; | ||
| // some initial mutations on `default_scope` here. | ||
|
|
||
| auto rnn_step_net = PlainNet(rnn_step_net_def); | ||
|
|
||
| // Create rnn's states, the last scope is used to store rnn outputs. | ||
| Scope *rnn_states = new Scope[num_states + 1]; | ||
|
|
||
| for (int i = 0; i < num_states + 1; i++) { | ||
| // Initialize all rnn state scopes, copy parameters and so on. | ||
| rnn_states[i].CreateVariables(rnn_step_net_def); | ||
| Copy(default_scope, rnn_states[i], rnn_related_vars); | ||
| // Prepare rnn's inlinks, just copy inlink variables to each state. | ||
| Copy(default_scope, rnn_states[i], inlink_vars); | ||
| } | ||
|
|
||
| // Run the rnn. | ||
| for (int i = 0; i < num_states; i++) { | ||
| rnn_step_net.Run(&rnn_states[i]); | ||
| // Copy the current state's state variables to the next state. The related variables | ||
| // are named like "previous_state_xxx". | ||
| Copy(rnn_states[i], rnn_states[i + 1], pre_state_vars); | ||
| } | ||
|
|
||
| // Copy rnn's final outputs to `default_scope`. | ||
| Copy(rnn_states[num_states], default_scope, outlink_vars); | ||
| ``` | ||
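
The design above can be tried out with a small self-contained sketch. It is only an illustration under stated assumptions, not the framework's implementation: `Scope` is reduced to a name-to-float map, `Operator` to a callable, and `CreateStepNet` and `RunUnrolled` are hypothetical helpers introduced here. It shows the two properties the document relies on: the net owns operators but no variables, and the same net can be run against a different scope at every RNN step.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the framework types (assumptions, not Paddle's classes).
using Scope = std::map<std::string, float>;     // variable name -> value
using Operator = std::function<bool(Scope *)>;  // reads and writes variables in a scope
typedef int OpIndex;

class NetworkBase {
 public:
  virtual ~NetworkBase() = default;
  // Runs the operators in [begin, end); negative indexes mean "run everything".
  virtual bool Run(Scope *scope, OpIndex begin = -1, OpIndex end = -1) const = 0;
};

class PlainNet final : public NetworkBase {
 public:
  explicit PlainNet(std::vector<Operator> ops) : ops_(std::move(ops)) {}

  bool Run(Scope *scope, OpIndex begin = -1, OpIndex end = -1) const override {
    assert(scope != nullptr);
    const OpIndex size = static_cast<OpIndex>(ops_.size());
    if (begin < 0) begin = 0;
    if (end < 0 || end > size) end = size;
    for (OpIndex i = begin; i < end; ++i) {
      if (!ops_[i](scope)) return false;  // stop at the first failing operator
    }
    return true;
  }

 private:
  std::vector<Operator> ops_;  // the net owns operators, never variables
};

// Factory returning a base-class pointer, in the spirit of `CreateNet`.
// The step computes h = prev_h * w + x using two operators.
std::unique_ptr<NetworkBase> CreateStepNet() {
  std::vector<Operator> ops;
  ops.push_back([](Scope *s) -> bool {
    (*s)["h"] = (*s)["prev_h"] * (*s)["w"];
    return true;
  });
  ops.push_back([](Scope *s) -> bool {
    (*s)["h"] += (*s)["x"];
    return true;
  });
  return std::unique_ptr<NetworkBase>(new PlainNet(std::move(ops)));
}

// Unrolls the step net over `inputs`, using one scope per time step.
float RunUnrolled(const std::vector<float> &inputs, float w) {
  std::unique_ptr<NetworkBase> net = CreateStepNet();
  std::vector<Scope> states(inputs.size());
  float prev_h = 0.0f;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    states[i]["w"] = w;            // shared parameter copied into each state
    states[i]["x"] = inputs[i];    // inlink for this step
    states[i]["prev_h"] = prev_h;  // state carried over from the previous step
    if (!net->Run(&states[i])) return -1.0f;
    prev_h = states[i]["h"];
  }
  return prev_h;
}
```

With these definitions, `RunUnrolled({1.0f, 2.0f, 3.0f}, 0.5f)` returns `4.25f`, and `net->Run(&scope, 1, 2)` runs only the second operator. Passing the scope into `Run` instead of storing it in the net is what lets one `PlainNet` serve every time step.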
`NetworkBase` has a member function called `InferShape(Scope*)` to infer shapes and create variables in the scope. Does it duplicate `scope.CreateVariables`?

Because the variables' information is located in `OpDef`, and there is no field called `vars` in `NetDef` that holds all the information about the variables required by the network, is `scope.CreateVariables` necessary?

@wangkuiyi @reyoung @jacquesqiao
`Variable` information is not stored in `NetDef`, and in `OpDef` there will only be input/output names for each `Variable`. So `CreateVariables` cannot be performed from `NetDef` alone.

Maybe there should be a high-level concept called `NetworkBuilder` or `Model`. A `NetworkBuilder` takes a `NetworkBase` and a `Scope` as its inputs. In `NetworkBuilder`, the user can add layers, add backward ops, and get all parameters. During the `AddLayer` functions, the `variable` is created by the network builder.
Can we merge `NetworkBuilder` into `NetworkBase`?

`NetworkBase` is the base class of all "network" implementations; with a different implementation, the NetBuilder may be different.

For example, `SimpleNet`'s network structure is different from `DAGNet`'s structure, so a different NetBuilder.ApplyGradient may be needed. In other words, NetBuilder is coupled to a specific implementation of NetworkBase.

If we merge NetBuilder's functions into `NetworkBase`'s specific implementations, then it is natural to have different "BackwardFrom" implementations.

@reyoung
Maybe the logic of `ApplyGradient` is the same for each implementation of `NetworkBase`. The difference between `PlainNet` and other implementations is how `AddOperator` is implemented.
I created another PR with NetBuilder.

`NetBuilder` is treated as a user-friendly syntax API. `NetBase` is the detailed implementation of network-related APIs.

Users can create a `BaseNet` using `NetBuilder`, but the details of `BaseNet` are not visible to them.

With this two-level design, we will focus our user-related API on `NetBuilder` and the implementation details only on `NetBase`.

The two are not coupled with each other.

@reyoung