Skip to content

Commit decd728

Browse files
committed
Moved internal Device type to ir + updated BUILD files and imports
1 parent 960220a commit decd728

File tree

6 files changed

+14
-11
lines changed

6 files changed

+14
-11
lines changed

core/conversion/conversionctx/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ cc_library(
2121
deps = [
2222
"@tensorrt//:nvinfer",
2323
"//core/util:prelude",
24+
"//core/ir",
2425
] + select({
2526
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
2627
"//conditions:default": ["@libtorch//:libtorch"],

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "torch/csrc/jit/ir/ir.h"
1010

1111
#include <cuda_runtime.h>
12+
#include "core/ir/ir.h"
1213
#include "core/util/prelude.h"
1314

1415
namespace torch_tensorrt {
@@ -22,7 +23,7 @@ struct BuilderSettings {
2223
bool refit = false;
2324
bool debug = false;
2425
bool truncate_long_and_double = false;
25-
torch_tensorrt::core::util::Device device;
26+
ir::Device device;
2627
nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
2728
nvinfer1::IInt8Calibrator* calibrator = nullptr;
2829
uint64_t num_avg_timing_iters = 1;

core/ir/ir.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ namespace torch_tensorrt {
1111
namespace core {
1212
namespace ir {
1313

14+
struct Device {
15+
nvinfer1::DeviceType device_type;
16+
int64_t gpu_id;
17+
int64_t dla_core;
18+
bool allow_gpu_fallback;
19+
Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
20+
};
21+
1422
struct Input : torch::CustomClassHolder {
1523
Input(){};
1624
Input(

core/lowering/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ cc_library(
2424
deps = [
2525
"//core/lowering/passes",
2626
"//core/util:prelude",
27+
"//core/ir",
2728
] + select({
2829
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
2930
"//conditions:default": ["@libtorch//:libtorch"],

core/lowering/lowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22
#include <memory>
3-
#include "core/util/prelude.h"
3+
#include "core/ir/ir.h"
44
#include "torch/csrc/jit/ir/ir.h"
55

66
namespace torch_tensorrt {
@@ -16,7 +16,7 @@ struct LowerInfo {
1616
// Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
1717
// pass. Disable this in order to not disturb TensorRT's QAT optimizations.
1818
bool disable_cse = false;
19-
torch_tensorrt::core::util::Device target_device;
19+
ir::Device target_device;
2020
std::vector<std::string> forced_fallback_modules;
2121
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
2222

core/util/trt_util.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,6 @@ namespace torch_tensorrt {
127127
namespace core {
128128
namespace util {
129129

130-
struct Device {
131-
nvinfer1::DeviceType device_type;
132-
int64_t gpu_id;
133-
int64_t dla_core;
134-
bool allow_gpu_fallback;
135-
Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
136-
};
137-
138130
int64_t volume(const nvinfer1::Dims& d);
139131

140132
bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional = true);

0 commit comments

Comments
 (0)