Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 60 additions & 9 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2834,6 +2834,66 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) {
#ifdef USE_TENSORRT
class CApiTensorRTTest : public testing::Test, public ::testing::WithParamInterface<std::string> {};

TEST_P(CApiTensorRTTest, TestExternalCUDAStreamWithIOBinding) {
const auto& api = Ort::GetApi();
OrtTensorRTProviderOptionsV2* trt_options;
ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr);
std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions);

// updating provider option with user provided compute stream
cudaStream_t compute_stream = nullptr;
void* user_compute_stream = nullptr;
cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr);
Comment thread
tianleiwu marked this conversation as resolved.
ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr);
ASSERT_TRUE(user_compute_stream == (void*)compute_stream);

std::basic_string<ORTCHAR_T> model_uri = MODEL_URI;

Ort::SessionOptions session_options;
ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast<OrtSessionOptions*>(session_options), rel_trt_options.get()) == nullptr);

std::unique_ptr<Ort::Session> session;
session.reset(new Ort::Session(*ort_env, model_uri.c_str(), session_options));

Ort::IoBinding iobindings(*session);

// input tensor on gpu
Ort::MemoryInfo memory_info_gpu{"Cuda", OrtDeviceAllocator, 0, OrtMemTypeDefault};
void* input_tensor_data = nullptr;
int type_size = 4;
std::array<float, 3 * 2> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
std::vector<int64_t> x_shape({3, 2});
size_t tensor_size = x_values.size();
ONNXTensorElementDataType type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
assert(cudaMalloc(&input_tensor_data, tensor_size*type_size) == cudaSuccess);
cudaMemcpy(input_tensor_data, x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice);
Ort::Value ort_input_tensor_value = Ort::Value::CreateTensor(memory_info_gpu, input_tensor_data, tensor_size * type_size,
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
x_shape.data(), x_shape.size(), type);

// output tensor on cpu
Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
const std::array<int64_t, 2> y_shape = {3, 2};
std::array<float, 3 * 2> y_values;
const std::array<float, 3 * 2> expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
Ort::Value ort_output_tensor_value = Ort::Value::CreateTensor(info_cpu, y_values.data(), y_values.size(),
y_shape.data(), y_shape.size());

iobindings.BindInput("X", ort_input_tensor_value);
iobindings.BindOutput("Y", ort_output_tensor_value);

session->Run(Ort::RunOptions(), iobindings);
Comment thread
tianleiwu marked this conversation as resolved.
Outdated

// Check the values against the bound raw memory
ASSERT_TRUE(std::equal(std::begin(y_values), std::end(y_values), std::begin(expected_y)));

iobindings.ClearBoundInputs();
iobindings.ClearBoundOutputs();

cudaFree(input_tensor_data);
cudaStreamDestroy(compute_stream);
}

// This test uses CreateTensorRTProviderOptions/UpdateTensorRTProviderOptions APIs to configure and create a TensorRT Execution Provider
TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) {
std::string param = GetParam();
Expand All @@ -2849,15 +2909,6 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) {
ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr);
std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions);

// Only test updating provider option with user provided compute stream
cudaStream_t compute_stream = nullptr;
void* user_compute_stream = nullptr;
cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);
ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr);
ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr);
ASSERT_TRUE(user_compute_stream == (void*)compute_stream);
cudaStreamDestroy(compute_stream);

const char* engine_cache_path = "./trt_engine_folder";

std::vector<const char*> keys{"device_id", "has_user_compute_stream", "trt_fp16_enable", "trt_int8_enable", "trt_engine_cache_enable",
Expand Down