Skip to content

Commit 5ae6942

Browse files
committed
feat: Save target platform as part of TRTEngine Metadata
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 4aa6e79 commit 5ae6942

File tree

17 files changed

+423
-108
lines changed

17 files changed

+423
-108
lines changed

core/runtime/BUILD

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ cc_library(
2121
name = "runtime",
2222
srcs = [
2323
"DeviceList.cpp",
24+
"Platform.cpp",
2425
"RTDevice.cpp",
2526
"TRTEngine.cpp",
2627
"TRTEngineProfiler.cpp",
@@ -29,6 +30,7 @@ cc_library(
2930
"runtime.cpp",
3031
],
3132
hdrs = [
33+
"Platform.h",
3234
"RTDevice.h",
3335
"TRTEngine.h",
3436
"TRTEngineProfiler.h",
@@ -41,16 +43,26 @@ cc_library(
4143
"//core/plugins:torch_tensorrt_plugins",
4244
"//core/util:prelude",
4345
] + select({
44-
":windows": ["@tensorrt_win//:nvinfer", "@libtorch_win//:libtorch"],
45-
":use_pre_cxx11_abi": ["@tensorrt//:nvinfer", "@libtorch_pre_cxx11_abi//:libtorch"],
46-
"//conditions:default": ["@tensorrt//:nvinfer", "@libtorch"],
46+
":use_pre_cxx11_abi": [
47+
"@libtorch_pre_cxx11_abi//:libtorch",
48+
"@tensorrt//:nvinfer",
49+
],
50+
":windows": [
51+
"@libtorch_win//:libtorch",
52+
"@tensorrt_win//:nvinfer",
53+
],
54+
"//conditions:default": [
55+
"@libtorch",
56+
"@tensorrt//:nvinfer",
57+
],
4758
}),
4859
alwayslink = True,
4960
)
5061

5162
pkg_tar(
5263
name = "include",
5364
srcs = [
65+
"Platform.h",
5466
"RTDevice.h",
5567
"TRTEngine.h",
5668
"TRTEngineProfiler.h",

core/runtime/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ set(CXX_SRCS
99
"${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
1010
"${CMAKE_CURRENT_SOURCE_DIR}/register_jit_hooks.cpp"
1111
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.cpp"
12+
"${CMAKE_CURRENT_SOURCE_DIR}/Platform.cpp"
1213
)
1314

1415
set(HEADER_FILES
1516
"${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.h"
1617
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h"
1718
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.h"
1819
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.h"
20+
"${CMAKE_CURRENT_SOURCE_DIR}/Platform.h"
1921
)
2022

2123
target_sources(${lib_name}

core/runtime/Platform.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#include "core/runtime/Platform.h"
2+
#include "core/runtime/runtime.h"
3+
#include "core/util/prelude.h"
4+
5+
namespace torch_tensorrt {
6+
namespace core {
7+
namespace runtime {
8+
9+
namespace {
10+
const std::unordered_map<std::string, Platform::PlatformEnum>& get_name_to_platform_map() {
11+
static const std::unordered_map<std::string, Platform::PlatformEnum> name_to_platform_map = {
12+
{"linux_aarch64", Platform::PlatformEnum::kLINUX_AARCH64},
13+
{"linux_x86_64", Platform::PlatformEnum::kLINUX_X86_64},
14+
{"windows_x86_64", Platform::PlatformEnum::kWIN_X86_64},
15+
{"unknown", Platform::PlatformEnum::kUNKNOWN},
16+
};
17+
return name_to_platform_map;
18+
}
19+
20+
const std::unordered_map<Platform::PlatformEnum, std::string>& _get_platform_name_map() {
21+
static const std::unordered_map<Platform::PlatformEnum, std::string> platform_name_map = {
22+
{Platform::PlatformEnum::kLINUX_AARCH64, "linux_aarch64"},
23+
{Platform::PlatformEnum::kLINUX_X86_64, "linux_x86_64"},
24+
{Platform::PlatformEnum::kWIN_X86_64, "windows_x86_64"},
25+
{Platform::PlatformEnum::kUNKNOWN, "unknown"}};
26+
return platform_name_map;
27+
}
28+
} // namespace
29+
30+
const std::unordered_map<Platform::PlatformEnum, std::string>& get_platform_name_map() {
31+
return _get_platform_name_map();
32+
}
33+
34+
// Default construction yields the unknown platform.
Platform::Platform() : _platform(Platform::PlatformEnum::kUNKNOWN) {}

// Wraps the given platform enum value as-is.
Platform::Platform(Platform::PlatformEnum val) : _platform(val) {}
37+
38+
// Builds a Platform from its serialized name (e.g. "linux_x86_64").
// Names that are not present in the name table produce a warning and resolve to kUNKNOWN.
Platform::Platform(const std::string& platform_str) {
  // Trace output only: this constructor runs whenever an engine is rebuilt from its
  // serialized info, so it must not be reported at error severity.
  LOG_DEBUG("Platform constructor: " << platform_str);
  // Bind by reference so the static lookup table is not copied on every construction.
  const auto& name_map = get_name_to_platform_map();
  const auto it = name_map.find(platform_str);
  if (it != name_map.end()) {
    _platform = it->second;
  } else {
    LOG_WARNING("Unknown platform " << platform_str);
    _platform = Platform::PlatformEnum::kUNKNOWN;
  }
}
49+
50+
std::string Platform::serialize() const {
51+
auto name_map = get_platform_name_map();
52+
auto it = name_map.find(_platform);
53+
if (it != name_map.end()) {
54+
return it->second;
55+
} else {
56+
LOG_WARNING("Attempted to serialized unknown platform tag");
57+
return std::string("unknown");
58+
}
59+
}
60+
61+
// Copy assignment: takes over the other object's platform tag and returns this object.
Platform& Platform::operator=(const Platform& other) {
  this->_platform = other._platform;
  return *this;
}
65+
66+
// Two Platform objects compare equal when they carry the same platform enum value.
bool operator==(const Platform& lhs, const Platform& rhs) {
  const bool same_platform = (lhs._platform == rhs._platform);
  return same_platform;
}
69+
70+
std::ostream& operator<<(std::ostream& os, const Platform& platform) {
71+
os << platform.serialize();
72+
return os;
73+
}
74+
75+
// Reports the platform this translation unit was compiled for, decided entirely by
// predefined compiler macros. Any OS/architecture combination not listed below
// is reported as kUNKNOWN.
Platform get_current_platform() {
  // Start from "unknown" and let the preprocessor pick at most one override.
  Platform::PlatformEnum detected = Platform::PlatformEnum::kUNKNOWN;
#if defined(__linux__) || defined(__gnu_linux__)
#if defined(__aarch64__)
  detected = Platform::PlatformEnum::kLINUX_AARCH64;
#elif defined(__amd64__) || defined(__x86_64__)
  detected = Platform::PlatformEnum::kLINUX_X86_64;
#endif
#elif defined(_WIN32) || defined(_WIN64)
#if defined(_M_AMD64) || defined(_M_X64)
  detected = Platform::PlatformEnum::kWIN_X86_64;
#endif
#endif
  return Platform(detected);
}
94+
95+
// Returns true when an engine built for `target` can run on the current platform.
// For now that means an exact platform match.
bool is_supported_on_current_platform(Platform target) {
  // Space for more complicated platform support calculations later
  const Platform current = get_current_platform();
  return target == current;
}
99+
100+
} // namespace runtime
101+
} // namespace core
102+
} // namespace torch_tensorrt

core/runtime/Platform.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
#include <string>
3+
#include <unordered_map>
4+
5+
namespace torch_tensorrt {
6+
namespace core {
7+
namespace runtime {
8+
9+
// Value type tagging the OS/architecture pair a TensorRT engine targets.
// It converts to and from a string name so it can be stored alongside the
// other serialized engine fields.
struct Platform {
  // Known target platforms; kUNKNOWN is the fallback for anything unrecognized.
  typedef enum {
    kLINUX_X86_64 = 0,
    kLINUX_AARCH64,
    kWIN_X86_64,
    kUNKNOWN,
  } PlatformEnum;

  // The wrapped platform tag; unknown until a constructor sets it.
  PlatformEnum _platform = Platform::kUNKNOWN;

  // Constructs an unknown platform.
  Platform();
  // Constructs from an enum value.
  Platform(PlatformEnum val);
  // Constructs from a serialized platform name.
  Platform(const std::string& platform_str);
  // Declared explicitly because a user-provided copy assignment operator is declared below;
  // relying on the implicitly generated copy constructor in that situation is deprecated
  // and triggers -Wdeprecated-copy when a Platform is passed by value.
  Platform(const Platform& other) = default;
  // Returns the serialized name of this platform.
  std::string serialize() const;
  Platform& operator=(const Platform& other);

  // Streams the serialized platform name.
  friend std::ostream& operator<<(std::ostream& os, const Platform& platform);
  // Compares the wrapped enum values.
  friend bool operator==(const Platform& lhs, const Platform& rhs);
};
28+
29+
const std::unordered_map<Platform::PlatformEnum, std::string>& get_platform_name_map();
30+
Platform get_current_platform();
31+
bool is_supported_on_current_platform(Platform target);
32+
33+
} // namespace runtime
34+
} // namespace core
35+
} // namespace torch_tensorrt

core/runtime/TRTEngine.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ TRTEngine::TRTEngine(
3434
const RTDevice& cuda_device,
3535
const std::vector<std::string>& _in_binding_names,
3636
const std::vector<std::string>& _out_binding_names,
37+
const Platform& target_platform,
3738
bool hardware_compatible,
3839
const std::string& serialized_metadata)
3940
: TRTEngine(
@@ -42,6 +43,7 @@ TRTEngine::TRTEngine(
4243
cuda_device,
4344
_in_binding_names,
4445
_out_binding_names,
46+
target_platform,
4547
hardware_compatible,
4648
serialized_metadata) {}
4749

@@ -52,6 +54,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
5254
RTDevice(serialized_info[DEVICE_IDX]),
5355
split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
5456
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
57+
Platform(serialized_info[TARGET_PLATFORM_IDX]),
5558
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
5659
serialized_info[SERIALIZED_METADATA_IDX]) {}
5760

@@ -61,12 +64,22 @@ TRTEngine::TRTEngine(
6164
const RTDevice& cuda_device,
6265
const std::vector<std::string>& _in_binding_names,
6366
const std::vector<std::string>& _out_binding_names,
67+
const Platform& target_platform,
6468
bool hardware_compatible,
6569
const std::string& serialized_metadata) {
70+
TORCHTRT_CHECK(
71+
is_supported_on_current_platform(target_platform),
72+
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
73+
<< get_current_platform() << ")");
74+
this->target_platform = target_platform;
75+
76+
this->cudagraph_mempool_id = at::cuda::graph_pool_handle();
77+
6678
this->hardware_compatible = hardware_compatible;
67-
this->serialized_metadata = serialized_metadata;
6879
auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
6980
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
81+
82+
this->serialized_metadata = serialized_metadata;
7083
device_info = most_compatible_device.value();
7184
multi_gpu_device_check();
7285
set_rt_device(device_info);
@@ -196,7 +209,6 @@ TRTEngine::TRTEngine(
196209
}
197210

198211
TRTEngine::~TRTEngine() {
199-
cudagraph.reset();
200212
trt_engine_profiler.reset();
201213
exec_ctx.reset();
202214
cuda_engine.reset();
@@ -276,6 +288,7 @@ std::string TRTEngine::to_str() const {
276288
ss << " ]" << std::endl;
277289
ss << " Device: " << device_info << std::endl;
278290
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
291+
ss << " Target Platform: " << target_platform << std::endl;
279292
// clang-format on
280293
return ss.str();
281294
}

core/runtime/TRTEngine.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,30 @@ struct TRTEngine : torch::CustomClassHolder {
3939
bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
4040
std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
4141
// in compilation
42+
Platform target_platform;
4243

4344
~TRTEngine();
4445
TRTEngine(
4546
const std::string& serialized_engine,
4647
const RTDevice& cuda_device,
4748
const std::vector<std::string>& in_binding_names,
4849
const std::vector<std::string>& out_binding_names,
50+
const Platform& target_platform = get_current_platform(),
4951
bool hardware_compatible = false,
5052
const std::string& serialized_metadata = "");
53+
5154
TRTEngine(std::vector<std::string> serialized_info);
55+
5256
TRTEngine(
5357
const std::string& mod_name,
5458
const std::string& serialized_engine,
5559
const RTDevice& cuda_device,
5660
const std::vector<std::string>& in_binding_names,
5761
const std::vector<std::string>& out_binding_names,
62+
const Platform& target_platform = get_current_platform(),
5863
bool hardware_compatible = false,
5964
const std::string& serialized_metadata = "");
65+
6066
TRTEngine& operator=(const TRTEngine& other);
6167
std::string to_str() const;
6268
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
@@ -75,6 +81,7 @@ struct TRTEngine : torch::CustomClassHolder {
7581
std::vector<at::Tensor> input_buffers = {};
7682
std::vector<at::Tensor> output_buffers = {};
7783
std::string shape_key;
84+
at::cuda::MempoolId_t cudagraph_mempool_id;
7885

7986
// TODO: Implement a call method
8087
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
328328
if (need_cudagraphs_record) {
329329
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
330330
c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
331-
compiled_engine->cudagraph.capture_begin();
331+
compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id);
332332
compiled_engine->exec_ctx->enqueueV3(recording_stream);
333333
compiled_engine->cudagraph.capture_end();
334334

core/runtime/register_jit_hooks.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <codecvt>
22

3+
#include "core/runtime/Platform.h"
34
#include "core/runtime/runtime.h"
5+
#include "core/util/macros.h"
46

57
namespace torch_tensorrt {
68
namespace core {
@@ -103,11 +105,14 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
103105
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
104106
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
105107
serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
108+
serialize_info[TARGET_PLATFORM_IDX] = self->target_platform.serialize();
106109
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
110+
LOG_DEBUG("Serialized Target Platform: " << self->target_platform);
107111

108112
return serialize_info;
109113
},
110114
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
115+
LOG_ERROR(serialized_info[TARGET_PLATFORM_IDX]);
111116
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
112117
TRTEngine::verify_serialization_fmt(serialized_info);
113118
return c10::make_intrusive<TRTEngine>(serialized_info);
@@ -137,7 +142,28 @@ TORCH_LIBRARY(tensorrt, m) {
137142
m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; });
138143
m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; });
139144
m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; });
145+
m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; });
140146
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
147+
m.def("_platform_linux_x86_64", []() -> std::string {
148+
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
149+
return it->second;
150+
});
151+
m.def("_platform_linux_aarch64", []() -> std::string {
152+
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_AARCH64);
153+
return it->second;
154+
});
155+
m.def("_platform_win_x86_64", []() -> std::string {
156+
auto it = get_platform_name_map().find(Platform::PlatformEnum::kWIN_X86_64);
157+
return it->second;
158+
});
159+
m.def("_platform_unknown", []() -> std::string {
160+
auto it = get_platform_name_map().find(Platform::PlatformEnum::kUNKNOWN);
161+
return it->second;
162+
});
163+
m.def("get_current_platform", []() -> std::string {
164+
auto it = get_platform_name_map().find(get_current_platform()._platform);
165+
return it->second;
166+
});
141167
}
142168

143169
} // namespace

core/runtime/runtime.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <utility>
66
#include "ATen/core/function_schema.h"
77
#include "NvInfer.h"
8+
#include "core/runtime/Platform.h"
89
#include "core/runtime/RTDevice.h"
910
#include "core/runtime/TRTEngine.h"
1011
#include "core/util/prelude.h"
@@ -15,7 +16,7 @@ namespace core {
1516
namespace runtime {
1617

1718
using EngineID = int64_t;
18-
const std::string ABI_VERSION = "5";
19+
const std::string ABI_VERSION = "6";
1920
extern bool MULTI_DEVICE_SAFE_MODE;
2021
extern bool CUDAGRAPHS_MODE;
2122

@@ -28,6 +29,7 @@ typedef enum {
2829
OUTPUT_BINDING_NAMES_IDX,
2930
HW_COMPATIBLE_IDX,
3031
SERIALIZED_METADATA_IDX,
32+
TARGET_PLATFORM_IDX,
3133
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
3234
} SerializedInfoIndex;
3335

@@ -47,7 +49,7 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode);
4749

4850
bool get_cudagraphs_mode();
4951

50-
void set_cudagraphs_mode(bool multi_device_safe_mode);
52+
void set_cudagraphs_mode(bool cudagraphs_mode);
5153

5254
class DeviceList {
5355
using DeviceMap = std::unordered_map<int, RTDevice>;

0 commit comments

Comments
 (0)