173 changes: 173 additions & 0 deletions paddle/framework/net_design.md
@@ -0,0 +1,173 @@
# Network Design

`Network` is the container and controller of a set of operators.
Users can build a real network from a `NetDef` protobuf message
and call `Network.Run()` to run all the operators in the network.

The `Network` will

- manage all the operators contained in the network.
- not own any `Variable`.

# API

## NetworkBase
To make `Network` extensible, a base class is defined like this:

```c++
// An operator's index within a network.
typedef int OpIndex;

// The minimal interface every network must implement.
class NetworkBase {
 public:
  // Concrete networks are built from a `NetDef`, a proto message that
  // describes the structure of a network.
  NetworkBase();

  // Virtual destructor so that implementations can be destroyed through a
  // `NetworkBase` pointer.
  virtual ~NetworkBase() {}

  // Infer the shapes of the variables required by the operators in the
  // network. `scope` will be mutated according to the inferred shapes.
  virtual bool InferShape(Scope *scope) = 0;

  // Run the operators and return whether the run succeeded. All variables
  // are located in `scope`. `begin` and `end` specify the range of `ops_` to
  // run; if no non-negative indexes are provided, all operators in `ops_`
  // will run.
  virtual bool Run(Scope *scope, OpIndex begin = -1,
                   OpIndex end = -1) const = 0;
};
```

Every network implementation should build its network from a protobuf message that
describes the structure of a real network. Every implementation should also provide
`Run`, so that there is one uniform way to run the forward or backward computation of a network.

A factory function can be defined like this:

```c++
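// Returns an owning pointer to the network described by `def`, or nullptr if
// the model type is not supported.
std::unique_ptr<NetworkBase> CreateNet(const NetDef& def) {
  switch (def.model_type()) {
    case NN:
      return std::unique_ptr<NetworkBase>(new Network(def));
    case Recursive:
      return std::unique_ptr<NetworkBase>(new RecursiveNet(def));
    case Recurrent:
      return std::unique_ptr<NetworkBase>(new RecurrentNet(def));
  }
  return nullptr;
}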
```

`Network` is designed as a container of operators. To make it more extensible,
we decouple it from the variable resources it operates on.

`Run(Scope* scope)` takes the scope as an argument so that the same network can run in different scopes.
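
To illustrate why passing the scope into `Run` matters, here is a minimal sketch in which one network object is run against several independent scopes. `Scope` is reduced to a name-to-value table and `DoubleNet` is a made-up network with a single hard-coded operator; both are assumptions for illustration, not the real types.

```c++
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in: a scope is just a name -> value table here.
struct Scope {
  std::map<std::string, float> vars;
};

// A stand-in network with a single hard-coded operator: y = 2 * x.
// It keeps no variables itself; everything is read from and written to `scope`.
class DoubleNet {
 public:
  bool Run(Scope *scope) const {
    if (scope == nullptr) return false;
    auto it = scope->vars.find("x");
    if (it == scope->vars.end()) return false;  // input variable missing
    const float x = it->second;
    scope->vars["y"] = 2.0f * x;
    return true;
  }
};

// Runs the same network object against a fresh scope seeded with `x`.
float RunInFreshScope(const DoubleNet &net, float x) {
  Scope scope;
  scope.vars["x"] = x;
  const bool ok = net.Run(&scope);
  assert(ok);
  (void)ok;  // avoid an unused-variable warning when asserts are disabled
  return scope.vars["y"];
}
```

Because the network holds no variables, calling `RunInFreshScope` twice with different inputs cannot leak state from one call into the next.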

Finally, `NetworkBase` can be used as follows:

```c++
Scope default_scope;
auto net = CreateNet(def);

// `CreateNet` returns a smart pointer, so the network is accessed with `->`.
if (net) {
  net->Run(&default_scope);
}
```
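
For reference, the factory and the usage pattern can be combined into one compilable sketch. The types below (`Scope`, `NetDef`, `ModelType`, `TrivialNet`) are simplified stand-ins invented for this example, and `CreateAndRun` is a hypothetical helper; the real `NetDef` is a protobuf message with its own accessors.

```c++
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-ins for the real types; only what the sketch needs.
struct Scope {
  std::map<std::string, float> vars;
};
enum class ModelType { NN, Unknown };
struct NetDef {
  ModelType model_type;
};
typedef int OpIndex;

class NetworkBase {
 public:
  virtual ~NetworkBase() = default;
  virtual bool Run(Scope *scope, OpIndex begin = -1,
                   OpIndex end = -1) const = 0;
};

// A trivial network that only records in the scope that it ran.
class TrivialNet final : public NetworkBase {
 public:
  explicit TrivialNet(const NetDef &) {}
  bool Run(Scope *scope, OpIndex = -1, OpIndex = -1) const override {
    if (scope == nullptr) return false;
    scope->vars["ran"] = 1.0f;
    return true;
  }
};

// Returns an owning pointer, or nullptr for an unsupported model type.
std::unique_ptr<NetworkBase> CreateNet(const NetDef &def) {
  switch (def.model_type) {
    case ModelType::NN:
      return std::unique_ptr<NetworkBase>(new TrivialNet(def));
    default:
      return nullptr;
  }
}

// Builds a net from `def` and runs it once; false if creation or Run fails.
bool CreateAndRun(const NetDef &def, Scope *scope) {
  std::unique_ptr<NetworkBase> net = CreateNet(def);
  return net && net->Run(scope);
}
```

Returning `std::unique_ptr` makes ownership explicit: the caller owns the network, and the null check covers model types the factory does not know.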

## A Simple Network Implementation

A very basic implementation is shown below. All it does is run every operator in sequence.

```c++
class PlainNet final : public NetworkBase {
 public:
  // Create a network described by `def`. `NetDef` is the definition of a
  // network.
  PlainNet(const NetDef &def);

  virtual bool InferShape(Scope *scope) override;

  // Run the operators in `scope`. If no scope is provided, the default scope
  // will be used instead. `begin` and `end` default to -1 (run all
  // operators), matching `NetworkBase::Run`.
  virtual bool Run(Scope *scope = nullptr, OpIndex begin = -1,
                   OpIndex end = -1) const override;

  const std::vector<Operator> &GetOps() const;

  std::vector<Operator> *MutableOps();

 protected:
  // Create operators according to `def`.
  bool CreateNet(const NetDef &def);

  // Add an operator which is identified as `type` and has attributes
  // described in `attrs`. `inputs` are the keys of read-only input variables,
  // `outputs` are the keys of mutable output variables. The returned
  // `OpIndex` is the offset of the new operator in `ops_`.
  OpIndex AddOp(const std::string &type,
                const std::vector<std::string> &inputs,
                const std::vector<std::string> &outputs,
                const OprAttr &attrs = OprAttr());

 private:
  // The operators owned by this network.
  std::vector<Operator> ops_;
};
```

`PlainNet` creates and owns its operators, so it defines a private member `ops_` to hold them.
`CreateNet` builds the whole set of operators from the `NetDef`, calling `AddOp` once for each operator.
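
A minimal sketch of how `AddOp` and the `begin`/`end` arguments of `Run` could behave is given below. It assumes a half-open range `[begin, end)` and treats each negative bound as "unset"; the design above does not say whether `end` is inclusive, so this is only one possible reading. `MiniPlainNet`, `LogOp`, the logging `Scope`, and the callable `Operator` are all hypothetical stand-ins.

```c++
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

typedef int OpIndex;

// Hypothetical stand-ins: a scope that logs which operators ran, and an
// operator modeled as a callable.
struct Scope {
  std::vector<std::string> log;
};
typedef std::function<bool(Scope *)> Operator;

class MiniPlainNet {
 public:
  // Appends `op` and returns its offset in `ops_`.
  OpIndex AddOp(Operator op) {
    ops_.push_back(std::move(op));
    return static_cast<OpIndex>(ops_.size()) - 1;
  }

  // Runs operators in [begin, end) in insertion order. A negative `begin`
  // means "from the first operator"; a negative or too-large `end` means
  // "through the last operator". Stops and returns false as soon as one
  // operator fails.
  bool Run(Scope *scope, OpIndex begin = -1, OpIndex end = -1) const {
    const OpIndex n = static_cast<OpIndex>(ops_.size());
    if (begin < 0) begin = 0;
    if (end < 0 || end > n) end = n;
    for (OpIndex i = begin; i < end; ++i) {
      if (!ops_[i](scope)) return false;
    }
    return true;
  }

 private:
  std::vector<Operator> ops_;
};

// Helper for the example: an operator that appends `name` to the scope's log.
Operator LogOp(const std::string &name) {
  return [name](Scope *scope) {
    scope->log.push_back(name);
    return true;
  };
}
```

With three operators added, `Run(&scope)` executes all of them in order, while `Run(&scope, 1, 2)` executes only the second one.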


## Usage
`PlainNet` can be used to define and run a network as follows:

```c++
// create an empty scope located on CPU device.
Scope scope(CPUPlace());

// create and init variables described in `net_desc`.
scope.CreateVariables(net_desc);
@Superjomn Superjomn Jun 23, 2017

NetworkBase has a member function called InferShape(Scope*) that infers shapes and creates variables in the scope. Does it duplicate the work of scope.CreateVariables?

Because the variables' information is located in OpDef, and NetDef has no field called vars that holds all the information about the variables required by the network, is scope.CreateVariables necessary?
@wangkuiyi @reyoung @jacquesqiao

@reyoung reyoung Jun 24, 2017

Variable information is not stored in NetDef, and OpDef contains only the input/output names of each Variable. So CreateVariables cannot be performed from NetDef alone.

Maybe there should be a high-level concept called NetworkBuilder or Model. A NetworkBuilder takes a NetworkBase and a Scope as its inputs. With a NetworkBuilder, the user can add layers, add backward ops, and get all parameters. Variables are created by the NetworkBuilder inside its AddLayer-style functions.

NetworkBase network;
Scope scope;
Variable* image = scope.CreateVariable("Image");
Variable* label = scope.CreateVariable("Label");
NetworkBuilder builder(&network, &scope);

Variable* fc_out = builder.FCLayer(/*input=*/image, /*size=*/100, /*activation=*/"Sigmoid");
Variable* prediction = builder.FCLayer(/*input=*/fc_out, /*size=*/10, /*activation=*/"Sigmoid");
Variable* loss = builder.CrossEntropy(/*input=*/prediction, /*label=*/label);
Variable* avg_loss = builder.Mean(loss);

auto allParams = builder.Parameters();
builder.BackwardFrom(avg_loss);
builder.AddOptimization(1e-4, "adam");

// train one mini-batch
network.Run(&scope);

@Superjomn Superjomn Jun 26, 2017

Can we merge NetworkBuilder into NetworkBase?
NetworkBase is the base class of all "network" implementations, and different implementations may need different NetBuilders.
For example, SimpleNet's network structure differs from DAGNet's structure,
so a different NetBuilder.ApplyGradient may be needed. In other words, NetBuilder is coupled to a specific implementation of NetworkBase.

If we merge NetBuilder's functions into the specific implementations of NetworkBase, then it is natural to have different "BackwardFrom" implementations.
@reyoung

@reyoung reyoung Jun 26, 2017

Maybe the logic of ApplyGradient is the same for every implementation of NetworkBase. The difference between PlainNet and the other implementations is how AddOperator is implemented.

@Superjomn Superjomn Jun 26, 2017

I created another PR with NetBuilder.
NetBuilder is treated as a user-friendly API layer.
NetBase is the detailed implementation of the network-related APIs.

Users can create a NetBase using NetBuilder, but the details of NetBase are not visible to them.
With this two-level design, user-facing APIs live in NetBuilder and implementation details live only in NetBase.

The two are not coupled to each other.

@reyoung

scope.InitVariables(net_desc);

// create a network according to `net_desc`
auto net = CreateNet(net_desc);

// run the network providing the `scope`.
net->Run(&scope);
```
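
The two-level split discussed in the review thread above (a user-facing builder on top of a network that only stores operators) can be sketched roughly as follows. Everything here is hypothetical: `Net`, `OpDesc`, the name-only `Scope`, and the single `FCLayer` method with its `fc_<n>` naming scheme are stand-ins chosen to keep the example short, not the API proposed in the follow-up PR.

```c++
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in: the scope only tracks variable names here.
struct Scope {
  std::set<std::string> var_names;
  void CreateVariable(const std::string &name) { var_names.insert(name); }
};

// Implementation level: the net only stores operator descriptions.
struct OpDesc {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};
struct Net {
  std::vector<OpDesc> ops;
};

// User level: adds layers by creating the output variable in the scope and
// appending the matching operator to the net. It owns neither of them.
class NetBuilder {
 public:
  NetBuilder(Net *net, Scope *scope) : net_(net), scope_(scope) {}

  // Adds a fully connected layer reading `input`; returns the output name.
  std::string FCLayer(const std::string &input) {
    const std::string out = "fc_" + std::to_string(net_->ops.size());
    scope_->CreateVariable(out);
    net_->ops.push_back(OpDesc{"fc", {input}, {out}});
    return out;
  }

 private:
  Net *net_;
  Scope *scope_;
};
```

In this arrangement the builder is the only place where variables are created, which is one way to address the question of who should call `CreateVariables`.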

## Compatibility with RNN

Because `PlainNet.Run` is decoupled from `Scope`, `PlainNet` is compatible with the future RNN design.
For example, we can implement a simple recurrent neural network as follows:

```c++
// Copy some `vars` from `source` to `target`.
void Copy(const Scope &source, Scope &target,
          const std::vector<std::string> &vars);

Scope default_scope;
// some initial mutations on `default_scope` here.

PlainNet rnn_step_net(rnn_step_net_def);

// Create the rnn's states; the last scope is used to store the rnn outputs.
// A vector owns the scopes, so nothing has to be deleted by hand.
std::vector<Scope> rnn_states(num_states + 1);

for (int i = 0; i < num_states + 1; i++) {
  // Initialize all rnn state scopes, copy parameters and so on.
  rnn_states[i].CreateVars(rnn_step_net_def);
  Copy(default_scope, rnn_states[i], rnn_related_vars);
  // Prepare the rnn's inlinks: just copy the inlink variables to each state.
  Copy(default_scope, rnn_states[i], inlink_vars);
}

// Run the rnn.
for (int i = 0; i < num_states; i++) {
  // `Run` takes a pointer to the scope of the current step.
  rnn_step_net.Run(&rnn_states[i]);
  // Copy the current state's state variables to the next state; the related
  // variables are named like "previous_state_xxx".
  Copy(rnn_states[i], rnn_states[i + 1], pre_state_vars);
}

// Copy the rnn's final outputs to `default_scope`.
Copy(rnn_states[num_states], default_scope, outlink_vars);
```
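
The same idea can be reduced to a small self-contained sketch: one step network is reused for every time step, and each step gets its own scope. The scalar `Scope`, the `StepNet` with its fixed `state = previous_state + x` operator, and the variable names are assumptions made for illustration only.

```c++
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in: a scope maps variable names to scalar values.
struct Scope {
  std::map<std::string, float> vars;
};

// Stand-in step network with one fixed operator: state = previous_state + x.
struct StepNet {
  bool Run(Scope *scope) const {
    if (scope == nullptr) return false;
    const float prev = scope->vars["previous_state"];
    const float x = scope->vars["x"];
    scope->vars["state"] = prev + x;
    return true;
  }
};

// Unrolls `step` over `inputs`, one scope per time step, and returns the
// state produced by the last step (0 if `inputs` is empty).
float RunRnn(const StepNet &step, const std::vector<float> &inputs) {
  std::vector<Scope> states(inputs.size());
  float carried = 0.0f;  // initial state
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    states[i].vars["previous_state"] = carried;  // link from previous step
    states[i].vars["x"] = inputs[i];             // inlink for this step
    if (!step.Run(&states[i])) return carried;
    carried = states[i].vars["state"];
  }
  return carried;
}
```

Since `StepNet` stores nothing between calls, all recurrence is carried by the values copied from one step's scope into the next.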