Skip to content

feat: Save target platform as part of TRTEngine Metadata #3106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions core/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ cc_library(
name = "runtime",
srcs = [
"DeviceList.cpp",
"Platform.cpp",
"RTDevice.cpp",
"TRTEngine.cpp",
"TRTEngineProfiler.cpp",
Expand All @@ -29,6 +30,7 @@ cc_library(
"runtime.cpp",
],
hdrs = [
"Platform.h",
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
Expand All @@ -41,16 +43,26 @@ cc_library(
"//core/plugins:torch_tensorrt_plugins",
"//core/util:prelude",
] + select({
":windows": ["@tensorrt_win//:nvinfer", "@libtorch_win//:libtorch"],
":use_pre_cxx11_abi": ["@tensorrt//:nvinfer", "@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@tensorrt//:nvinfer", "@libtorch"],
":use_pre_cxx11_abi": [
"@libtorch_pre_cxx11_abi//:libtorch",
"@tensorrt//:nvinfer",
],
":windows": [
"@libtorch_win//:libtorch",
"@tensorrt_win//:nvinfer",
],
"//conditions:default": [
"@libtorch",
"@tensorrt//:nvinfer",
],
}),
alwayslink = True,
)

pkg_tar(
name = "include",
srcs = [
"Platform.h",
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
Expand Down
2 changes: 2 additions & 0 deletions core/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ set(CXX_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/register_jit_hooks.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/Platform.cpp"
)

set(HEADER_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.h"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.h"
"${CMAKE_CURRENT_SOURCE_DIR}/Platform.h"
)

target_sources(${lib_name}
Expand Down
102 changes: 102 additions & 0 deletions core/runtime/Platform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#include "core/runtime/Platform.h"
#include "core/runtime/runtime.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace runtime {

namespace {
const std::unordered_map<std::string, Platform::PlatformEnum>& get_name_to_platform_map() {
static const std::unordered_map<std::string, Platform::PlatformEnum> name_to_platform_map = {
{"linux_aarch64", Platform::PlatformEnum::kLINUX_AARCH64},
{"linux_x86_64", Platform::PlatformEnum::kLINUX_X86_64},
{"windows_x86_64", Platform::PlatformEnum::kWIN_X86_64},
{"unknown", Platform::PlatformEnum::kUNKNOWN},
};
return name_to_platform_map;
}

const std::unordered_map<Platform::PlatformEnum, std::string>& _get_platform_name_map() {
static const std::unordered_map<Platform::PlatformEnum, std::string> platform_name_map = {
{Platform::PlatformEnum::kLINUX_AARCH64, "linux_aarch64"},
{Platform::PlatformEnum::kLINUX_X86_64, "linux_x86_64"},
{Platform::PlatformEnum::kWIN_X86_64, "windows_x86_64"},
{Platform::PlatformEnum::kUNKNOWN, "unknown"}};
return platform_name_map;
}
} // namespace

// Public accessor for the enum-to-name table (the table itself is kept
// file-local in the anonymous namespace above).
const std::unordered_map<Platform::PlatformEnum, std::string>& get_platform_name_map() {
  const auto& table = _get_platform_name_map();
  return table;
}

// Default-construct as an unknown platform.
Platform::Platform() : _platform(Platform::PlatformEnum::kUNKNOWN) {}

// Construct directly from an enum tag.
Platform::Platform(Platform::PlatformEnum val) : _platform(val) {}

// Construct from a serialized platform tag (e.g. "linux_x86_64").
// Unrecognized strings map to kUNKNOWN with a warning.
Platform::Platform(const std::string& platform_str) {
  // Normal construction path: log at debug, not error, severity
  // (previously LOG_ERROR — a debugging leftover).
  LOG_DEBUG("Platform constructor: " << platform_str);
  // Bind by reference; `auto` here copied the entire static map per call.
  const auto& name_map = get_name_to_platform_map();
  auto it = name_map.find(platform_str);
  if (it != name_map.end()) {
    _platform = it->second;
  } else {
    LOG_WARNING("Unknown platform " << platform_str);
    _platform = Platform::PlatformEnum::kUNKNOWN;
  }
}

// Return the canonical string tag for this platform; falls back to
// "unknown" (with a warning) if the enum value has no mapping.
std::string Platform::serialize() const {
  // Bind by reference; `auto` here copied the entire static map per call.
  const auto& name_map = get_platform_name_map();
  auto it = name_map.find(_platform);
  if (it != name_map.end()) {
    return it->second;
  } else {
    // Fixed grammar in the warning ("serialized" -> "serialize").
    LOG_WARNING("Attempted to serialize unknown platform tag");
    return std::string("unknown");
  }
}

// Copy-assignment: adopt the other platform's tag.
Platform& Platform::operator=(const Platform& other) {
  this->_platform = other._platform;
  return *this;
}

// Two Platforms compare equal iff their enum tags match.
bool operator==(const Platform& lhs, const Platform& rhs) {
  return rhs._platform == lhs._platform;
}

// Stream the canonical platform tag (same text as serialize()).
std::ostream& operator<<(std::ostream& os, const Platform& platform) {
  return os << platform.serialize();
}

// Detect the platform this binary was compiled for, via predefined compiler
// macros. Recognized: Linux x86_64 / aarch64 and Windows x86_64; anything
// else yields kUNKNOWN.
Platform get_current_platform() {
#if defined(__linux__) || defined(__gnu_linux__)
#if defined(__aarch64__)
  return Platform(Platform::PlatformEnum::kLINUX_AARCH64);
#elif defined(__amd64__) || defined(__x86_64__)
  return Platform(Platform::PlatformEnum::kLINUX_X86_64);
#else
  // Linux on an architecture we do not target (e.g. ppc64le)
  return Platform(Platform::PlatformEnum::kUNKNOWN);
#endif
#elif defined(_WIN32) || defined(_WIN64)
#if defined(_M_AMD64) || defined(_M_X64)
  return Platform(Platform::PlatformEnum::kWIN_X86_64);
#else
  // Windows on a non-x86_64 architecture (e.g. ARM64)
  return Platform(Platform::PlatformEnum::kUNKNOWN);
#endif
#else
  // Neither Linux nor Windows (e.g. macOS)
  return Platform(Platform::PlatformEnum::kUNKNOWN);
#endif
}

// A target platform is runnable here iff it exactly matches the platform
// this binary was built for. (Placeholder for richer compatibility rules.)
bool is_supported_on_current_platform(Platform target) {
  return get_current_platform() == target;
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
35 changes: 35 additions & 0 deletions core/runtime/Platform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
#include <string>
#include <unordered_map>

namespace torch_tensorrt {
namespace core {
namespace runtime {

// Identifies the OS/architecture combination a TensorRT engine was built
// for. Round-trips through a string tag (e.g. "linux_x86_64") so it can be
// stored in serialized TRTEngine metadata.
struct Platform {
  typedef enum {
    kLINUX_X86_64 = 0,
    kLINUX_AARCH64,
    kWIN_X86_64,
    kUNKNOWN,
  } PlatformEnum;

  // Underlying tag; defaults to unknown until set by a constructor.
  PlatformEnum _platform = Platform::kUNKNOWN;

  Platform();
  Platform(PlatformEnum val);
  // Parse from a serialized tag; unrecognized strings yield kUNKNOWN.
  Platform(const std::string& platform_str);
  // Canonical string tag for this platform (inverse of the string ctor).
  std::string serialize() const;
  Platform& operator=(const Platform& other);

  // Streams the canonical tag; equality compares the enum tags.
  friend std::ostream& operator<<(std::ostream& os, const Platform& device);
  friend bool operator==(const Platform& lhs, const Platform& rhs);
};

// Enum -> canonical name table shared with the serialization code.
const std::unordered_map<Platform::PlatformEnum, std::string>& get_platform_name_map();
// Platform this binary was compiled for (compile-time macro detection).
Platform get_current_platform();
// True iff an engine built for `target` can run on the current platform.
bool is_supported_on_current_platform(Platform target);

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
17 changes: 15 additions & 2 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ TRTEngine::TRTEngine(
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
const Platform& target_platform,
bool hardware_compatible,
const std::string& serialized_metadata)
: TRTEngine(
Expand All @@ -42,6 +43,7 @@ TRTEngine::TRTEngine(
cuda_device,
_in_binding_names,
_out_binding_names,
target_platform,
hardware_compatible,
serialized_metadata) {}

Expand All @@ -52,6 +54,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
RTDevice(serialized_info[DEVICE_IDX]),
split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
Platform(serialized_info[TARGET_PLATFORM_IDX]),
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
serialized_info[SERIALIZED_METADATA_IDX]) {}

Expand All @@ -61,12 +64,22 @@ TRTEngine::TRTEngine(
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
const Platform& target_platform,
bool hardware_compatible,
const std::string& serialized_metadata) {
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
<< get_current_platform() << ")");
this->target_platform = target_platform;

this->cudagraph_mempool_id = at::cuda::graph_pool_handle();

this->hardware_compatible = hardware_compatible;
this->serialized_metadata = serialized_metadata;
auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");

this->serialized_metadata = serialized_metadata;
device_info = most_compatible_device.value();
multi_gpu_device_check();
set_rt_device(device_info);
Expand Down Expand Up @@ -196,7 +209,6 @@ TRTEngine::TRTEngine(
}

TRTEngine::~TRTEngine() {
cudagraph.reset();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this not needed ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The object destructor calls it and its potentially unsafe to do it ourselves

trt_engine_profiler.reset();
exec_ctx.reset();
cuda_engine.reset();
Expand Down Expand Up @@ -276,6 +288,7 @@ std::string TRTEngine::to_str() const {
ss << " ]" << std::endl;
ss << " Device: " << device_info << std::endl;
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
ss << " Target Platform: " << target_platform << std::endl;
// clang-format on
return ss.str();
}
Expand Down
7 changes: 7 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,30 @@ struct TRTEngine : torch::CustomClassHolder {
bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
// in compilation
Platform target_platform;

~TRTEngine();
TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names,
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
const std::string& serialized_metadata = "");

TRTEngine(std::vector<std::string> serialized_info);

TRTEngine(
const std::string& mod_name,
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names,
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
const std::string& serialized_metadata = "");

TRTEngine& operator=(const TRTEngine& other);
std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
Expand All @@ -75,6 +81,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;
at::cuda::MempoolId_t cudagraph_mempool_id;

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
Expand Down
2 changes: 1 addition & 1 deletion core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
if (need_cudagraphs_record) {
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
compiled_engine->cudagraph.capture_begin();
compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id);
compiled_engine->exec_ctx->enqueueV3(recording_stream);
compiled_engine->cudagraph.capture_end();

Expand Down
26 changes: 26 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <codecvt>

#include "core/runtime/Platform.h"
#include "core/runtime/runtime.h"
#include "core/util/macros.h"

namespace torch_tensorrt {
namespace core {
Expand Down Expand Up @@ -103,11 +105,14 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
serialize_info[TARGET_PLATFORM_IDX] = self->target_platform.serialize();
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
LOG_DEBUG("Serialized Target Platform: " << self->target_platform);

return serialize_info;
},
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
LOG_ERROR(serialized_info[TARGET_PLATFORM_IDX]);
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
TRTEngine::verify_serialization_fmt(serialized_info);
return c10::make_intrusive<TRTEngine>(serialized_info);
Expand Down Expand Up @@ -137,7 +142,28 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; });
m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; });
m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; });
m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; });
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
m.def("_platform_linux_x86_64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
return it->second;
});
m.def("_platform_linux_aarch64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_AARCH64);
return it->second;
});
m.def("_platform_win_x86_64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kWIN_X86_64);
return it->second;
});
m.def("_platform_unknown", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kUNKNOWN);
return it->second;
});
m.def("get_current_platform", []() -> std::string {
auto it = get_platform_name_map().find(get_current_platform()._platform);
return it->second;
});
}

} // namespace
Expand Down
6 changes: 4 additions & 2 deletions core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <utility>
#include "ATen/core/function_schema.h"
#include "NvInfer.h"
#include "core/runtime/Platform.h"
#include "core/runtime/RTDevice.h"
#include "core/runtime/TRTEngine.h"
#include "core/util/prelude.h"
Expand All @@ -15,7 +16,7 @@ namespace core {
namespace runtime {

using EngineID = int64_t;
const std::string ABI_VERSION = "5";
const std::string ABI_VERSION = "6";
extern bool MULTI_DEVICE_SAFE_MODE;
extern bool CUDAGRAPHS_MODE;

Expand All @@ -28,6 +29,7 @@ typedef enum {
OUTPUT_BINDING_NAMES_IDX,
HW_COMPATIBLE_IDX,
SERIALIZED_METADATA_IDX,
TARGET_PLATFORM_IDX,
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

Expand All @@ -47,7 +49,7 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode);

bool get_cudagraphs_mode();

void set_cudagraphs_mode(bool multi_device_safe_mode);
void set_cudagraphs_mode(bool cudagraphs_mode);

class DeviceList {
using DeviceMap = std::unordered_map<int, RTDevice>;
Expand Down
Loading
Loading