Skip to content

Commit 5fdef1c

Browse files
committed
test: [collection] update model path in test_collection.cpp
Signed-off-by: inocsin <[email protected]>
1 parent 1de4a33 commit 5fdef1c

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

tests/cpp/test_collection.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
TEST(CppAPITests, TestCollectionNormalInput) {
1010

11-
std::string path = "/root/Torch-TensorRT/normal_model.ts";
11+
std::string path = "tests/modules/normal_model.jit.pt";
1212
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
1313
std::vector<at::Tensor> inputs;
1414
inputs.push_back(in0);
@@ -53,7 +53,7 @@ TEST(CppAPITests, TestCollectionNormalInput) {
5353

5454
TEST(CppAPITests, TestCollectionTupleInput) {
5555

56-
std::string path = "/root/Torch-TensorRT/tuple_input.ts";
56+
std::string path = "tests/modules/tuple_input.jit.pt";
5757
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
5858

5959
torch::jit::Module mod;
@@ -103,7 +103,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {
103103

104104
TEST(CppAPITests, TestCollectionListInput) {
105105

106-
std::string path = "/root/Torch-TensorRT/list_input.ts";
106+
std::string path = "tests/modules/list_input.jit.pt";
107107
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
108108
std::vector<at::Tensor> inputs;
109109
inputs.push_back(in0);
@@ -169,7 +169,7 @@ TEST(CppAPITests, TestCollectionListInput) {
169169

170170
TEST(CppAPITests, TestCollectionTupleInputOutput) {
171171

172-
std::string path = "/root/Torch-TensorRT/tuple_input_output.ts";
172+
std::string path = "tests/modules/tuple_input_output.jit.pt";
173173

174174
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
175175

@@ -224,7 +224,7 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {
224224

225225
TEST(CppAPITests, TestCollectionListInputOutput) {
226226

227-
std::string path = "/root/Torch-TensorRT/list_input_output.ts";
227+
std::string path = "tests/modules/list_input_output.jit.pt";
228228
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
229229
std::vector<at::Tensor> inputs;
230230
inputs.push_back(in0);
@@ -296,7 +296,7 @@ TEST(CppAPITests, TestCollectionListInputOutput) {
296296

297297
TEST(CppAPITests, TestCollectionComplexModel) {
298298

299-
std::string path = "/root/Torch-TensorRT/complex_model.ts";
299+
std::string path = "tests/modules/complex_model.jit.pt";
300300
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
301301
std::vector<at::Tensor> inputs;
302302
inputs.push_back(in0);

tests/modules/hub.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,29 +256,29 @@ def forward(self, z: List[torch.Tensor]):
256256
normal_model = Normal()
257257
normal_model_ts = torch.jit.script(normal_model)
258258
normal_model_ts.to("cuda").eval()
259-
torch.jit.save(normal_model_ts, "normal_model.ts")
259+
torch.jit.save(normal_model_ts, "normal_model.jit.pt")
260260

261261
tuple_input = TupleInput()
262262
tuple_input_ts = torch.jit.script(tuple_input)
263263
tuple_input_ts.to("cuda").eval()
264-
torch.jit.save(tuple_input_ts, "tuple_input.ts")
264+
torch.jit.save(tuple_input_ts, "tuple_input.jit.pt")
265265

266266
list_input = ListInput()
267267
list_input_ts = torch.jit.script(list_input)
268268
list_input_ts.to("cuda").eval()
269-
torch.jit.save(list_input_ts, "list_input.ts")
269+
torch.jit.save(list_input_ts, "list_input.jit.pt")
270270

271271
tuple_input = TupleInputOutput()
272272
tuple_input_ts = torch.jit.script(tuple_input)
273273
tuple_input_ts.to("cuda").eval()
274-
torch.jit.save(tuple_input_ts, "tuple_input_output.ts")
274+
torch.jit.save(tuple_input_ts, "tuple_input_output.jit.pt")
275275

276276
list_input = ListInputOutput()
277277
list_input_ts = torch.jit.script(list_input)
278278
list_input_ts.to("cuda").eval()
279-
torch.jit.save(list_input_ts, "list_input_output.ts")
279+
torch.jit.save(list_input_ts, "list_input_output.jit.pt")
280280

281281
complex_model = ComplexModel()
282282
complex_model_ts = torch.jit.script(complex_model)
283283
complex_model_ts.to("cuda").eval()
284-
torch.jit.save(complex_model_ts, "complex_model.ts")
284+
torch.jit.save(complex_model_ts, "complex_model.jit.pt")

0 commit comments

Comments
 (0)