-
Notifications
You must be signed in to change notification settings - Fork 5.9k
fea/init tensorrt engine #10003
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fea/init tensorrt engine #10003
Changes from 27 commits
a3140d3
a60189f
b95d819
87fc090
8dda580
92480b5
5891896
1b475b3
9d617b8
0e8e85f
63b6a74
d492547
4f0a2ab
e220226
5132a2b
dc23dc5
1fe9f63
cf4f092
f1b5040
9699574
aa7ab53
1d13858
4da8cbd
5463325
74ea1f6
610f290
f273eef
57c0ddb
25397ca
6d89b54
97a34ac
5b8de3b
bbf19cb
4c0ce9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "paddle/fluid/framework/framework.pb.h" | ||
|
|
||
| namespace paddle { | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 按照Paddle对于命名空间的使用规则,这里应该还有一层
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可能会有下面的目录
|
||
| /* | ||
| * EngineBase is the base class of all inference engines. An inference engine | ||
| * takes a paddle program as input, and output the result in paddle Tensor | ||
|
||
| * format. It can be used to optimize performance of computation subgraphs, for | ||
| * example, break down the original model into subgraphs and execute each | ||
|
||
| * subgraph in different engines. | ||
| * | ||
| * For example: | ||
| * When inference, the resnet50 model can put most of the model into subgraph | ||
| * and run it on a TensorRT engine. | ||
| * | ||
| * There are several engines such as TensorRT and other internal frameworks, so | ||
|
||
| * an EngineBase is put forward to give an unified interface for all the | ||
| * different engine implemention. | ||
| */ | ||
| class EngineBase { | ||
| public: | ||
| // TODO fix it latter | ||
|
||
| using PbType = int; // proto::BlockDesc; | ||
|
|
||
| // Build the model and do some preparation, for example, in TensorRT, run | ||
| // createInferBuilder, buildCudaEngine. | ||
| virtual void Build(const PbType& paddle_model) = 0; | ||
|
|
||
| // Execute the engine, that will run the inference network. | ||
| virtual void Execute(int batch_size) = 0; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Execute函数是不是最好以Paddle的LoDTensor类型作为参数?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个 engine 只是个工具类,会由 我的理解, engine 的 input 和 output 以及个数都不固定,有 这个类会暴露出比较多的小接口,这些接口会帮助构建 tensorrt 的network 以及 runtime engine, 中间小接口基本是不可少的。 最重要的用处是在
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
这两个和Convert类里的ConvertInput和ConvertOutput里,除了转的那步,功能很类似。有更好的设计办法么? Decl这四个字母含义不够清晰,能否就叫Add呢?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 会改成 DeclareInput,表示在 TensorRT network中添加 data 节点 |
||
|
|
||
| virtual ~EngineBase() {} | ||
|
|
||
| }; // class EngineBase | ||
|
|
||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) | ||
| nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,189 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/fluid/inference/tensorrt/engine.h" | ||
|
|
||
| #include <NvInfer.h> | ||
| #include <cuda.h> | ||
| #include <glog/logging.h> | ||
| #include "paddle/fluid/platform/dynload/tensorrt.h" | ||
| #include "paddle/fluid/platform/enforce.h" | ||
|
|
||
| namespace dy = paddle::platform::dynload; | ||
|
|
||
| namespace paddle { | ||
|
|
||
| size_t AccumDims(nvinfer1::Dims dims) { | ||
| size_t num = dims.nbDims == 0 ? 0 : 1; | ||
| for (int i = 0; i < dims.nbDims; i++) { | ||
| PADDLE_ENFORCE_GT(dims.d[i], 0); | ||
| LOG(INFO) << "dim.d: " << i << " " << dims.d[i]; | ||
| num *= dims.d[i]; | ||
| } | ||
| return num; | ||
| } | ||
|
|
||
| const int kDataTypeSize[] = { | ||
| 4, // kFLOAT | ||
| 2, // kHALF | ||
| 1, // kINT8 | ||
| 4 // kINT32 | ||
| }; | ||
|
||
|
|
||
| void TensorrtEngine::Build(const PbType& paddle_model) { | ||
| PADDLE_ENFORCE(false, "not implemented"); | ||
| } | ||
|
|
||
| void TensorrtEngine::Execute(int batch_size) { | ||
| infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr); | ||
| cudaStreamSynchronize(*stream_); | ||
| } | ||
|
|
||
| TensorrtEngine::~TensorrtEngine() { | ||
| // clean buffer | ||
| for (auto& buffer : buffers_) { | ||
| if (buffer != nullptr) { | ||
| PADDLE_ENFORCE_EQ(0, cudaFree(buffer)); | ||
| buffer = nullptr; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| namespace { | ||
|
|
||
| class Logger : public nvinfer1::ILogger { | ||
|
||
| public: | ||
| void log(nvinfer1::ILogger::Severity severity, const char* msg) override { | ||
| switch (severity) { | ||
| case Severity::kINFO: | ||
| LOG(INFO) << msg; | ||
| break; | ||
| case Severity::kWARNING: | ||
| LOG(WARNING) << msg; | ||
| break; | ||
| case Severity::kINTERNAL_ERROR: | ||
| case Severity::kERROR: | ||
| LOG(ERROR) << msg; | ||
| break; | ||
| default: | ||
| break; | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| // The following two API are implemented in TensorRT's header file, cannot load | ||
| // from the dynamic library. So create our own implementation and directly | ||
| // trigger the method from the dynamic library. | ||
| nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) { | ||
| return static_cast<nvinfer1::IBuilder*>( | ||
| dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION)); | ||
|
||
| } | ||
| nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) { | ||
| return static_cast<nvinfer1::IRuntime*>( | ||
| dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION)); | ||
|
||
| } | ||
| } // namespace | ||
|
|
||
| void TensorrtEngine::InitNetwork() { | ||
| Logger logger; | ||
| infer_builder_.reset(createInferBuilder(logger)); | ||
| infer_network_.reset(infer_builder_->createNetwork()); | ||
| } | ||
|
|
||
| void TensorrtEngine::FreezeNetwork() { | ||
| PADDLE_ENFORCE(infer_builder_ != nullptr, | ||
| "Call InitNetwork first to initialize network."); | ||
| PADDLE_ENFORCE(infer_network_ != nullptr, | ||
| "Call InitNetwork first to initialize network."); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 108行和110行的enforce内容一样?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. InitNetwork() |
||
| // build engine. | ||
| infer_builder_->setMaxBatchSize(max_batch_); | ||
| infer_builder_->setMaxWorkspaceSize(max_workspace_); | ||
|
|
||
| infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_)); | ||
| PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!"); | ||
|
|
||
| infer_context_.reset(infer_engine_->createExecutionContext()); | ||
|
|
||
| // allocate GPU buffers. | ||
| buffers_.resize(buffer_sizes_.size(), nullptr); | ||
| for (auto& item : buffer_sizes_) { | ||
| if (item.second == 0) { | ||
| auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str()); | ||
| item.second = kDataTypeSize[static_cast<int>( | ||
| infer_engine_->getBindingDataType(slot_offset))] * | ||
| AccumDims(infer_engine_->getBindingDimensions(slot_offset)); | ||
| } | ||
| PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second)); | ||
| } | ||
| } | ||
|
|
||
| nvinfer1::ITensor* TensorrtEngine::DeclInput(const std::string& name, | ||
| data_type dtype, | ||
| const dim_type& dim) { | ||
| PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s", | ||
| name); | ||
|
|
||
| PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first"); | ||
| auto* input = infer_network_->addInput(name.c_str(), dtype, dim); | ||
| PADDLE_ENFORCE(input, "infer network add input %s failed", name); | ||
|
|
||
| buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim); | ||
| return input; | ||
| } | ||
|
|
||
| void TensorrtEngine::DeclOutput(nvinfer1::ILayer* layer, int offset, | ||
| const std::string& name) { | ||
|
||
| PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s", | ||
| name); | ||
|
|
||
| auto* output = layer->getOutput(offset); | ||
| PADDLE_ENFORCE(output != nullptr); | ||
| output->setName(name.c_str()); | ||
| infer_network_->markOutput(*output); | ||
| buffer_sizes_[name] = 0; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为什么156行设成0?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. + comment |
||
| // * data_size; | ||
|
||
| } | ||
|
|
||
| void* TensorrtEngine::GetOutputInGPU(const std::string& name) { | ||
| return buffer(name); | ||
| } | ||
|
|
||
| void TensorrtEngine::GetOutputInCPU(const std::string& name, void* dst, | ||
| size_t max_size) { | ||
| // determine data size | ||
| auto it = buffer_sizes_.find(name); | ||
| PADDLE_ENFORCE(it != buffer_sizes_.end()); | ||
| PADDLE_ENFORCE_GT(it->second, 0); | ||
| PADDLE_ENFORCE_GE(max_size, it->second); | ||
|
|
||
| PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second, | ||
| cudaMemcpyDeviceToHost, *stream_)); | ||
| } | ||
|
|
||
| void*& TensorrtEngine::buffer(const std::string& name) { | ||
| PADDLE_ENFORCE(infer_engine_ != nullptr, "call freezenetwork first."); | ||
|
||
| auto it = buffer_sizes_.find(name); | ||
| PADDLE_ENFORCE(it != buffer_sizes_.end()); | ||
| auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); | ||
| return buffers_[slot_offset]; | ||
| } | ||
|
|
||
| void TensorrtEngine::SetInputFromCPU(const std::string& name, void* data, | ||
| size_t size) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. size_t size也是const类型 |
||
| void* buf = buffer(name); | ||
| PADDLE_ENFORCE_EQ( | ||
| 0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_)); | ||
| } | ||
|
|
||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <NvInfer.h> | ||
| #include <memory> | ||
| #include <unordered_map> | ||
| #include "paddle/fluid/inference/engine.h" | ||
|
|
||
| namespace paddle { | ||
|
|
||
| /* | ||
| * TensorRT Engine. | ||
| * | ||
| * There are two alternative ways to use it, one is to build from a paddle | ||
| * protobuf model, another way is to manully construct the network. | ||
| */ | ||
| class TensorrtEngine : public EngineBase { | ||
|
||
| public: | ||
| using data_type = nvinfer1::DataType; | ||
|
||
| using dim_type = nvinfer1::Dims; | ||
|
|
||
| // Weight is model parameter. | ||
| class Weight { | ||
| public: | ||
| Weight(data_type dtype, void* value, int num_elem) { | ||
|
||
| w_.type = dtype; | ||
| w_.values = value; | ||
| w_.count = num_elem; | ||
| } | ||
| const nvinfer1::Weights& get() { return w_; } | ||
|
|
||
| private: | ||
| nvinfer1::Weights w_; | ||
| }; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. weight class放在TensorrtEngine class里合适么?这个class也能被convert class调用。 |
||
|
|
||
| TensorrtEngine(int max_batch, int max_workspace, cudaStream_t* stream) | ||
| : max_batch_(max_batch), max_workspace_(max_workspace), stream_(stream) {} | ||
|
|
||
| virtual ~TensorrtEngine(); | ||
|
|
||
| // TODO(Superjomn) implement it latter when graph segmentation is supported. | ||
|
||
| virtual void Build(const PbType& paddle_model) override; | ||
|
|
||
| virtual void Execute(int batch_size) override; | ||
|
|
||
| // Initialize the infer network, so that layers can add to this network. | ||
|
||
| void InitNetwork(); | ||
| // Finished adding layers, freeze this network and creates the executation | ||
|
||
| // environment. | ||
| void FreezeNetwork(); | ||
|
|
||
| // Add an input and set its namd, data type and dimention. | ||
|
||
| nvinfer1::ITensor* DeclInput(const std::string& name, data_type dtype, | ||
| const dim_type& dim); | ||
| // Set the offset-th output from a layer as the network's output, and set its | ||
| // name. | ||
| void DeclOutput(nvinfer1::ILayer* layer, int offset, const std::string& name); | ||
|
|
||
| // GPU memory address for a tensor with specific name. One can operate on | ||
|
||
| // these memory directly for acceleration, for example, output the converted | ||
| // data directly to the buffer to save data copy overhead. | ||
| // NOTE this should be used after calling `FreezeNetwork`. | ||
| void*& buffer(const std::string& name); | ||
|
|
||
| // Fill an input from CPU memory with name and size. | ||
| void SetInputFromCPU(const std::string& name, void* data, size_t size); | ||
| // TODO(Superjomn) is this method necessary given that buffer(xxx) can be | ||
| // accessed directly. Fill an input from GPU memory with name and size. | ||
| void SetInputFromGPU(const std::string& name, void* data, size_t size); | ||
| // Get an output called name, the output of tensorrt is in GPU, so this method | ||
| // will just return the output's GPU memory address. | ||
| void* GetOutputInGPU(const std::string& name); | ||
| // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU | ||
| // to CPU. | ||
| void GetOutputInCPU(const std::string& name, void* dst, size_t max_size); | ||
|
|
||
| nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); } | ||
| nvinfer1::INetworkDefinition* network() { return infer_network_.get(); } | ||
|
|
||
| private: | ||
| int max_batch_; | ||
| int max_workspace_; | ||
| cudaStream_t* stream_; | ||
|
|
||
| std::vector<void*> buffers_; | ||
| // max data size for the buffers. | ||
| std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_; | ||
|
|
||
| template <typename T> | ||
| struct Destroyer { | ||
| void operator()(T* x) { x->destroy(); } | ||
| }; | ||
|
|
||
| template <typename T> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 111-114行变量请加注释 |
||
| using infer_ptr = std::unique_ptr<T, Destroyer<T>>; | ||
| infer_ptr<nvinfer1::IBuilder> infer_builder_; | ||
| infer_ptr<nvinfer1::INetworkDefinition> infer_network_; | ||
| infer_ptr<nvinfer1::ICudaEngine> infer_engine_; | ||
| infer_ptr<nvinfer1::IExecutionContext> infer_context_; | ||
| }; // class TensorrtEngine | ||
|
|
||
| // Add an layer__ into engine__ with args ARGS. | ||
| // For example: | ||
| // TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias) | ||
| // | ||
| // Reference | ||
| // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network | ||
| // | ||
| // will add a fully connected layer into the engine. | ||
| // TensorRT has too many layers, so that is not wise to add member functions for | ||
| // them, and an macro like this is more extensible when underlying TensorRT | ||
| // library add new layer supports. | ||
| #define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \ | ||
| engine__->network()->add##layer__(ARGS); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 请问这个宏定义可以去掉么?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个宏会
|
||
|
|
||
| } // namespace paddle | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copyright的格式调整一下吧。另外新加入的文件copyright里面的年份应该是2018年。