Implement framework::Variable #2587
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,3 @@ | ||
| //#include <stdexcept> | ||
| //#include <unittest/unittest.h> | ||
| #include <sstream> | ||
| #include <vector> | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| /* | ||
| Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include <memory> | ||
| #include <typeindex> | ||
| #include <typeinfo> | ||
|
|
||
| #include "paddle/platform/assert.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
| class Variable { | ||
| public: | ||
| template <typename T> | ||
| const T& Get() const { | ||
| PADDLE_ASSERT(holder_ != nullptr); | ||
| PADDLE_ASSERT(std::type_index(typeid(T)) == | ||
| std::type_index(holder_->Type())); | ||
| return *static_cast<const T*>(holder_->Ptr()); | ||
| } | ||
|
|
||
| template <typename T> | ||
| T* GetMutable() { | ||
| if (holder_ == nullptr || | ||
| std::type_index(typeid(T)) != std::type_index(holder_->Type())) { | ||
| holder_.reset(new PlaceholderImpl<T>(new T())); | ||
| } | ||
| return static_cast<T*>(holder_->Ptr()); | ||
| } | ||
|
|
||
| private: | ||
| struct Placeholder { | ||
| virtual ~Placeholder() {} | ||
| virtual const std::type_info& Type() const = 0; | ||
| virtual void* Ptr() const = 0; | ||
| }; | ||
|
|
||
| // Placeholder hides type T, so it doesn't appear as a template | ||
| // parameter of Variable. | ||
| template <typename T> | ||
| struct PlaceholderImpl : public Placeholder { | ||
| PlaceholderImpl(T* ptr) : ptr_(ptr), type_(typeid(T)) {} | ||
|
|
||
| virtual const std::type_info& Type() const { return type_; } | ||
| virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); } | ||
|
Contributor
We also want to use
Collaborator
Author
Since cuDNN and other BLAS libraries need a raw pointer from unique_ptr or other smart pointers anyway, I think it is OK (and indeed mandatory) to call BTW, I guess you wanted to say "We had thought about using std::unique_ptr ..."? :-)
Contributor
You're right. Thanks! :) |
||
|
|
||
| std::unique_ptr<T> ptr_; | ||
| const std::type_info& type_; | ||
|
Contributor
There is a unique id for the different type in the Caffe2's Here, we store the |
||
| }; | ||
|
|
||
| std::unique_ptr<Placeholder> | ||
| holder_; // points to a PlaceholderImpl object indeed. | ||
| }; | ||
|
|
||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| # Design Doc: Variable | ||
|
|
||
|
|
||
| Variable is also known as *blob* in MXNet and Caffe2. It is the input and output type of operators, where a neural network is a graph of operators. | ||
|
|
||
| ## Requirements: Lazy Memory Allocation | ||
|
|
||
| For the flexibility of a DL system, a variable should be able to contain any typed value -- a tensor in most cases, but could also be some integer IDs or a scope of other variables in the case of RNN. | ||
|
|
||
| To use the minimum amount of memory, we'd like a variable to allocate memory only when it has to, that is, lazy memory allocation. Consider the following example: | ||
|
|
||
| ```cpp | ||
| Variable vr, v1, v2; | ||
|
|
||
| Tensor* t1 = new Tensor(); | ||
| Tensor* t2 = new Tensor(); | ||
|
|
||
| Randomize( | ||
| /* malloc */ v1.GetMutable<Tensor>()->mutable_data<float16>(DDim(100,200)), | ||
| /* size */ t1->Size()); | ||
|
|
||
| Randomize( | ||
| /* malloc */ v2.GetMutable<Tensor>()->mutable_data<float16>(DDim(200,300)), | ||
| /* size */ t2->Size()); | ||
|
|
||
| Mult( | ||
| /*result*/ vr.GetMutable<Tensor>()->mutable_data<v1.Type()>(SizeOfMult(v1, v2)), | ||
| /*input1*/ v1.Get<Tensor>().data(), | ||
| /*input2*/ v2.Get<Tensor>().data()); | ||
| ``` | ||
|
|
||
| We see that a variable holds nothing until `Variable::GetMutable<Tensor>()` allocates a tensor and puts it in the variable. Similarly, a tensor gets no memory until `Tensor::mutable_data()` is called. | ||
|
Collaborator
It seems that 'lazy allocation' is split into two parts: the part in In
Tensor<Device, T, D> {
  // return an empty tensor with ptr_ pointing to nullptr
  Tensor();
  // return a tensor with newly allocated memory
  Tensor(Dim<D> size);
  // check whether ptr_ points to nullptr
  bool is_empty();
}
In The replacement costs little time because the empty tensor holds no memory. The benefit of this design is that the whole implementation of 'lazy allocation' is now limited in
Collaborator
Author
I think Variable can hold anything -- it doesn't have to be a Tensor, it could even be a Scope. Thus we might not initialize a Variable with an empty tensor.
Collaborator
Got it, thanks! |
||
|
|
||
| This syntax performs lazy memory allocation at the point where we call `Randomize` and `Mult`, the functions that mutate the variable, so it saves us some lines of C++ code. | ||
|
|
||
|
|
||
| ## Implementation: Type Hiding | ||
|
|
||
| To make memory allocation lazy, we cannot assume that we know the type held by a variable at definition time. In other words, `class Variable` cannot be a template `template <typename T> class Variable`. | ||
|
|
||
| Because we don't know the type `T`, we cannot save a `T*` as a data member of `Variable`. Instead, we save an interface object `Placeholder`, which can return the pointer to the saved object via `Placeholder::Ptr()` as `void*`. | ||
|
|
||
| Still, `Variable` needs to know `T` so that it can `delete<T>(ptr)` and so that `Variable::Get` can check the expected type against the saved object's type. | ||
|
|
||
| We save `T` in `PlaceholderImpl`, the implementation of `Placeholder`. Please be aware that `PlaceholderImpl` is a class template and `T` is passed in as a template parameter. | ||
|
|
||
| Because `PlaceholderImpl` knows `T`, it can save and return `typeid(T)` for the type comparison in `Variable::Get` and `Variable::GetMutable`. | ||
|
|
||
|
|
||
| ## Conclusion | ||
|
|
||
| The type hiding technique utilizes C++ class templates, interface and derivation, and C++ RTTI (typeid). This combination saves us from defining something like `caffe2::TypeMeta`, which takes hundreds of lines of C++ code. | ||
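
The type-hiding idea from the design doc can be sketched in isolation as follows. This is a minimal illustration, not the code in this PR: `AnyBox`, `Holder`, `HolderImpl`, `Holds`, `Tracked`, and `Demo` are hypothetical names introduced here. It shows that the box allocates nothing until `GetMutable<T>()` is called, and that replacing the held value destroys the old object with the right destructor even though the outer class is not a template.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>

// Hypothetical stand-in for Variable; all names are illustrative only.
class AnyBox {
 public:
  // Lazily creates a T if nothing is held or a different type is held.
  template <typename T>
  T* GetMutable() {
    if (!holder_ || std::type_index(typeid(T)) != holder_->Type()) {
      holder_.reset(new HolderImpl<T>(new T()));
    }
    return static_cast<T*>(holder_->Ptr());
  }

  // True if the box currently holds an object of exactly type T.
  template <typename T>
  bool Holds() const {
    return holder_ && std::type_index(typeid(T)) == holder_->Type();
  }

 private:
  // Interface that hides T from AnyBox.
  struct Holder {
    virtual ~Holder() {}
    virtual std::type_index Type() const = 0;
    virtual void* Ptr() const = 0;
  };

  // Knows T, so it can report typeid(T) and delete the object as a T.
  template <typename T>
  struct HolderImpl : public Holder {
    explicit HolderImpl(T* ptr) : ptr_(ptr) {}
    std::type_index Type() const override { return std::type_index(typeid(T)); }
    void* Ptr() const override { return ptr_.get(); }
    std::unique_ptr<T> ptr_;
  };

  std::unique_ptr<Holder> holder_;
};

// Counts live instances so construction and destruction can be observed.
struct Tracked {
  static int live;
  Tracked() { ++live; }
  ~Tracked() { --live; }
};
int Tracked::live = 0;

void Demo() {
  AnyBox box;
  assert(Tracked::live == 0);  // nothing allocated yet
  box.GetMutable<Tracked>();   // first use allocates a Tracked
  assert(Tracked::live == 1 && box.Holds<Tracked>());
  *box.GetMutable<std::string>() = "hello";  // replaces the Tracked
  assert(Tracked::live == 0 && box.Holds<std::string>());
  assert(*box.GetMutable<std::string>() == "hello");  // same object reused
}
```

Calling `Demo()` exercises the lazy allocation and the type switch. The sketch returns `std::type_index` from the holder instead of storing a `const std::type_info&`, which is only one possible choice.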
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| /* | ||
| Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| */ | ||
|
|
||
| #include <memory> | ||
| #include <string> | ||
|
|
||
| #include "gtest/gtest.h" | ||
| #include "paddle/framework/variable.h" | ||
|
|
||
| TEST(Variable, GetMutable) { | ||
| using paddle::framework::Variable; | ||
|
|
||
| struct Tensor { | ||
| int content_; | ||
| }; | ||
|
|
||
| std::unique_ptr<Variable> v(new Variable()); | ||
|
|
||
| Tensor* t = v->GetMutable<Tensor>(); | ||
| t->content_ = 1234; | ||
|
|
||
| const Tensor& tt = v->Get<Tensor>(); | ||
| EXPECT_EQ(1234, tt.content_); | ||
|
|
||
| std::string* s = v->GetMutable<std::string>(); | ||
| *s = "hello"; | ||
|
|
||
| const std::string& ss = v->Get<std::string>(); | ||
| EXPECT_EQ("hello", ss); | ||
| } |
In the variable design, we want to add a `Reset(T* ptr)` function, which can delete the hosted pointer and set a new ptr. This interface is useful when we want to reuse a `Variable` and do not need to new another empty `Variable`. Should we add this function?

I had `Reset` but deleted it later, because it is outside the imagined usage in the design doc (README.md). I think we can add it at any time in the future if we believe we do need it.
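
For reference, here is a rough sketch of what such a `Reset(T* ptr)` could look like if it were added later. It is an assumption about one possible shape, not code from this PR: `ResettableVariable` and `ResetDemo` are hypothetical names, and plain `assert` stands in for `PADDLE_ASSERT` so the example compiles on its own.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>

// Hypothetical variant of Variable with the Reset discussed above.
class ResettableVariable {
 public:
  // Takes ownership of ptr and deletes whatever was held before.
  template <typename T>
  void Reset(T* ptr) {
    holder_.reset(new PlaceholderImpl<T>(ptr));
  }

  // Returns the held object; the requested type must match the held type.
  template <typename T>
  const T& Get() const {
    assert(holder_ != nullptr);
    assert(std::type_index(typeid(T)) == std::type_index(holder_->Type()));
    return *static_cast<const T*>(holder_->Ptr());
  }

 private:
  struct Placeholder {
    virtual ~Placeholder() {}
    virtual const std::type_info& Type() const = 0;
    virtual void* Ptr() const = 0;
  };

  template <typename T>
  struct PlaceholderImpl : public Placeholder {
    explicit PlaceholderImpl(T* ptr) : ptr_(ptr) {}
    const std::type_info& Type() const override { return typeid(T); }
    void* Ptr() const override { return static_cast<void*>(ptr_.get()); }
    std::unique_ptr<T> ptr_;
  };

  std::unique_ptr<Placeholder> holder_;
};

void ResetDemo() {
  ResettableVariable v;
  v.Reset(new int(42));  // T is deduced as int
  assert(v.Get<int>() == 42);
  v.Reset(new std::string("hello"));  // the int is deleted here
  assert(v.Get<std::string>() == "hello");
}
```

The same variable is reused for two different types without constructing a second one, which is the use case raised in the comment above.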