2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/engine.h

@@ -65,7 +65,7 @@ class TensorRTEngine : public EngineBase {
   // Initialize the inference network, so that TensorRT layers can add to this
   // network.
   void InitNetwork() {
-    infer_builder_.reset(createInferBuilder(logger_));
+    infer_builder_.reset(createInferBuilder(&logger_));
     infer_network_.reset(infer_builder_->createNetwork());
   }
   // After finishing adding ops, freeze this network and create the execution
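The engine.h change is purely at the call site: createInferBuilder now takes nvinfer1::ILogger*, so the member logger is passed by address. A minimal sketch of the reference-to-pointer parameter change, using illustrative stand-in types (only createInferBuilder and logger_ appear in the diff):

    #include <cstdio>

    // Stand-ins for nvinfer1::ILogger and nvinfer1::IBuilder.
    struct Logger {};
    struct Builder {};

    // Before: Builder* CreateBuilder(Logger& logger);
    // After: a pointer parameter forces callers to write &logger, making it
    // visible at the call site that the object escapes to the callee.
    Builder* CreateBuilder(Logger* logger) {
      std::printf("creating builder with logger %p\n",
                  static_cast<void*>(logger));
      return new Builder;
    }

    int main() {
      Logger logger;
      Builder* builder = CreateBuilder(&logger);  // the & is now explicit
      delete builder;
      return 0;
    }

This matches the Google C++ style rule (cpplint's runtime/references) that parameters a callee may mutate or retain be passed by pointer, never by non-const reference.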
10 changes: 5 additions & 5 deletions paddle/fluid/inference/tensorrt/helper.h

@@ -46,13 +46,13 @@ const int kDataTypeSize[] = {
 // The following two APIs are implemented in TensorRT's header file and
 // cannot be loaded from the dynamic library, so we provide our own
 // implementations that call into the dynamic library directly.
-static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IBuilder*>(
-      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
-      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }

 // A logger used to create the TensorRT infer builder.
@@ -80,7 +80,7 @@ class NaiveLogger : public nvinfer1::ILogger {
     return *x;
   }

-  virtual ~NaiveLogger() override {}
+  ~NaiveLogger() override {}
 };

 }  // namespace tensorrt
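With the pointer-taking signatures, every caller of the two factories now writes the address-of explicitly. A minimal usage sketch, assuming the paddle::inference::tensorrt namespace and a NaiveLogger::Global() accessor returning nvinfer1::ILogger& (consistent with the `return *x;` visible in the second hunk; verify both against the actual header):

    #include "NvInfer.h"
    #include "paddle/fluid/inference/tensorrt/helper.h"

    namespace trt = paddle::inference::tensorrt;  // assumed namespace

    void BuildAndTearDown() {
      // Assumed accessor: NaiveLogger::Global() -> nvinfer1::ILogger&.
      nvinfer1::ILogger& logger = trt::NaiveLogger::Global();
      // Both factories take a pointer now, so the & appears at the call site.
      nvinfer1::IBuilder* builder = trt::createInferBuilder(&logger);
      nvinfer1::IRuntime* runtime = trt::createInferRuntime(&logger);
      builder->destroy();
      runtime->destroy();
    }

The destructor edit in the second hunk belongs to the same style pass: `override` already implies `virtual`, so `virtual ~NaiveLogger() override` states the property twice and cpplint (readability/inheritance) flags the redundancy.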
26 changes: 13 additions & 13 deletions paddle/fluid/inference/tensorrt/test_tensorrt.cc

@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <cuda.h>
+#include <cuda_runtime_api.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 #include "NvInfer.h"
-#include "cuda.h"
-#include "cuda_runtime_api.h"
 #include "paddle/fluid/platform/dynload/tensorrt.h"

 namespace dy = paddle::platform::dynload;
@@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger {

 class ScopedWeights {
  public:
-  ScopedWeights(float value) : value_(value) {
+  explicit ScopedWeights(float value) : value_(value) {
     w.type = nvinfer1::DataType::kFLOAT;
     w.values = &value_;
     w.count = 1;
@@ -58,13 +58,13 @@ class ScopedWeights {
 // The following two APIs are implemented in TensorRT's header file and
 // cannot be loaded from the dynamic library, so we provide our own
 // implementations that call into the dynamic library directly.
-nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IBuilder*>(
-      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
-      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }

 const char* kInputTensor = "input";
@@ -74,7 +74,7 @@ const char* kOutputTensor = "output";

 nvinfer1::IHostMemory* CreateNetwork() {
   Logger logger;
   // Create the engine.
-  nvinfer1::IBuilder* builder = createInferBuilder(logger);
+  nvinfer1::IBuilder* builder = createInferBuilder(&logger);
   ScopedWeights weights(2.);
   ScopedWeights bias(3.);

@@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
   return model;
 }

-void Execute(nvinfer1::IExecutionContext& context, const float* input,
+void Execute(nvinfer1::IExecutionContext* context, const float* input,
              float* output) {
-  const nvinfer1::ICudaEngine& engine = context.getEngine();
+  const nvinfer1::ICudaEngine& engine = context->getEngine();
   // Two bindings, input and output.
   ASSERT_EQ(engine.getNbBindings(), 2);
   const int input_index = engine.getBindingIndex(kInputTensor);
@@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
   // Copy the input to the GPU, execute the network, and copy the output back.
   ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
                                cudaMemcpyHostToDevice, stream));
-  context.enqueue(1, buffers, stream, nullptr);
+  context->enqueue(1, buffers, stream, nullptr);
   ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
                                cudaMemcpyDeviceToHost, stream));
   cudaStreamSynchronize(stream);
@@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) {

   // Use the model to create an engine and an execution context.
   Logger logger;
-  nvinfer1::IRuntime* runtime = createInferRuntime(logger);
+  nvinfer1::IRuntime* runtime = createInferRuntime(&logger);
   nvinfer1::ICudaEngine* engine =
       runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
   model->destroy();
@@ -145,7 +145,7 @@
   // Execute the network.
   float input = 1234;
   float output;
-  Execute(*context, &input, &output);
+  Execute(context, &input, &output);
   EXPECT_EQ(output, input * 2 + 3);

   // Destroy the engine.
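Taken together, the test-file edits read like the same lint-driven cleanup: the system headers move to angle brackets and sort to the top, the single-argument constructor gains explicit, and the mutable reference parameters become pointers. A self-contained sketch of the three cpplint rules involved (rule names are cpplint's; the code is illustrative, not from the PR):

    // runtime/explicit: one-argument constructors are marked explicit so
    // they cannot act as implicit conversions (cf. ScopedWeights above).
    class ScopedValue {
     public:
      explicit ScopedValue(float v) : v_(v) {}
      float get() const { return v_; }

     private:
      float v_;
    };

    // runtime/references: parameters a function mutates are pointers, never
    // non-const references, so call sites show an & (cf. Execute above).
    void Scale(float factor, float* value) { *value *= factor; }

    // readability/inheritance: override already implies virtual
    // (cf. ~NaiveLogger in helper.h).
    class Base {
     public:
      virtual ~Base() = default;
    };
    class Derived : public Base {
     public:
      ~Derived() override = default;
    };

    int main() {
      ScopedValue sv(2.0f);  // ScopedValue sv = 2.0f; would not compile
      float x = 1234.0f;
      Scale(sv.get(), &x);   // the mutation of x is visible at the call site
      return 0;
    }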