6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
- #include < gtest/gtest.h>
10
-
11
9
#include < executorch/extension/module/module.h>
10
+
12
11
#include < array>
12
+ #include < thread>
13
+
14
+ #include < gtest/gtest.h>
15
+
16
+ #include < executorch/extension/data_loader/file_data_loader.h>
13
17
14
18
using namespace ::testing;
15
19
16
20
namespace torch ::executor {
17
21
18
- class ModuleTest : public ::testing::Test {};
22
+ class ModuleTest : public ::testing::Test {
23
+ protected:
24
+ static void SetUpTestSuite () {
25
+ model_path_ = std::getenv (" RESOURCES_PATH" ) + std::string (" /model.pte" );
26
+ }
27
+
28
+ static std::string model_path_;
29
+ };
30
+
31
+ std::string ModuleTest::model_path_;
19
32
20
33
TEST_F (ModuleTest, TestLoad) {
21
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
34
+ Module module (model_path_ );
22
35
23
36
EXPECT_FALSE (module.is_loaded ());
24
37
const auto error = module.load ();
@@ -43,7 +56,7 @@ TEST_F(ModuleTest, TestLoadCorruptedFile) {
43
56
}
44
57
45
58
TEST_F (ModuleTest, TestMethodNames) {
46
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
59
+ Module module (model_path_ );
47
60
48
61
const auto method_names = module.method_names ();
49
62
EXPECT_TRUE (method_names.ok ());
@@ -58,7 +71,7 @@ TEST_F(ModuleTest, TestNonExistentMethodNames) {
58
71
}
59
72
60
73
TEST_F (ModuleTest, TestLoadMethod) {
61
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
74
+ Module module (model_path_ );
62
75
63
76
EXPECT_FALSE (module.is_method_loaded (" forward" ));
64
77
const auto error = module.load_method (" forward" );
@@ -68,7 +81,7 @@ TEST_F(ModuleTest, TestLoadMethod) {
68
81
}
69
82
70
83
TEST_F (ModuleTest, TestLoadNonExistentMethod) {
71
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
84
+ Module module (model_path_ );
72
85
73
86
const auto error = module.load_method (" backward" );
74
87
EXPECT_NE (error, Error::Ok);
@@ -77,7 +90,7 @@ TEST_F(ModuleTest, TestLoadNonExistentMethod) {
77
90
}
78
91
79
92
TEST_F (ModuleTest, TestMethodMeta) {
80
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
93
+ Module module (model_path_ );
81
94
82
95
const auto meta = module.method_meta (" forward" );
83
96
EXPECT_TRUE (meta.ok ());
@@ -109,7 +122,7 @@ TEST_F(ModuleTest, TestNonExistentMethodMeta) {
109
122
}
110
123
111
124
TEST_F (ModuleTest, TestExecute) {
112
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
125
+ Module module (model_path_ );
113
126
114
127
std::array<float , 2 > input{1 , 2 };
115
128
std::array<int32_t , 2 > sizes{1 , 2 };
@@ -127,7 +140,7 @@ TEST_F(ModuleTest, TestExecute) {
127
140
}
128
141
129
142
TEST_F (ModuleTest, TestExecutePreload) {
130
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
143
+ Module module (model_path_ );
131
144
132
145
const auto error = module.load ();
133
146
EXPECT_EQ (error, Error::Ok);
@@ -146,7 +159,7 @@ TEST_F(ModuleTest, TestExecutePreload) {
146
159
}
147
160
148
161
TEST_F (ModuleTest, TestExecutePreload_method) {
149
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
162
+ Module module (model_path_ );
150
163
151
164
const auto error = module.load_method (" forward" );
152
165
EXPECT_EQ (error, Error::Ok);
@@ -165,7 +178,7 @@ TEST_F(ModuleTest, TestExecutePreload_method) {
165
178
}
166
179
167
180
TEST_F (ModuleTest, TestExecutePreloadProgramAndMethod) {
168
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
181
+ Module module (model_path_ );
169
182
170
183
const auto load_error = module.load ();
171
184
EXPECT_EQ (load_error, Error::Ok);
@@ -203,7 +216,7 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) {
203
216
}
204
217
205
218
TEST_F (ModuleTest, TestGet) {
206
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
219
+ Module module (model_path_ );
207
220
208
221
std::array<float , 2 > input{1 , 2 };
209
222
std::array<int32_t , 2 > sizes{1 , 2 };
@@ -218,8 +231,7 @@ TEST_F(ModuleTest, TestGet) {
218
231
}
219
232
220
233
TEST_F (ModuleTest, TestForward) {
221
- auto module = std::make_unique<Module>(
222
- std::getenv (" RESOURCES_PATH" ) + std::string (" /model.pte" ));
234
+ auto module = std::make_unique<Module>(model_path_);
223
235
224
236
std::array<float , 2 > input{1 , 2 };
225
237
std::array<int32_t , 2 > sizes{1 , 2 };
@@ -244,11 +256,153 @@ TEST_F(ModuleTest, TestForward) {
244
256
}
245
257
246
258
TEST_F (ModuleTest, TestForwardWithInvalidInputs) {
247
- Module module (std::getenv ( " RESOURCES_PATH " ) + std::string ( " /model.pte " ) );
259
+ Module module (model_path_ );
248
260
249
261
const auto result = module.forward ({EValue ()});
250
262
251
263
EXPECT_FALSE (result.ok ());
252
264
}
253
265
266
+ TEST_F (ModuleTest, TestProgramSharingBetweenModules) {
267
+ Module module1 (model_path_);
268
+ EXPECT_FALSE (module1.is_loaded ());
269
+
270
+ auto load_error = module1.load ();
271
+ EXPECT_EQ (load_error, Error::Ok);
272
+ EXPECT_TRUE (module1.is_loaded ());
273
+
274
+ Module module2 (module1.program ());
275
+ EXPECT_TRUE (module2.is_loaded ());
276
+
277
+ auto method_names1 = module1.method_names ();
278
+ EXPECT_TRUE (method_names1.ok ());
279
+
280
+ auto method_names2 = module2.method_names ();
281
+ EXPECT_TRUE (method_names2.ok ());
282
+ EXPECT_EQ (method_names1.get (), method_names2.get ());
283
+
284
+ auto load_method_error = module1.load_method (" forward" );
285
+ EXPECT_EQ (load_method_error, Error::Ok);
286
+ EXPECT_TRUE (module1.is_method_loaded (" forward" ));
287
+ EXPECT_FALSE (module2.is_method_loaded (" forward" ));
288
+
289
+ auto load_method_error2 = module2.load_method (" forward" );
290
+ EXPECT_EQ (load_method_error2, Error::Ok);
291
+ EXPECT_TRUE (module2.is_method_loaded (" forward" ));
292
+ }
293
+
294
+ TEST_F (ModuleTest, TestProgramSharingAndDataLoaderManagement) {
295
+ auto loader = util::FileDataLoader::from (model_path_.c_str ());
296
+ EXPECT_TRUE (loader.ok ());
297
+ auto data_loader =
298
+ std::make_unique<util::FileDataLoader>(std::move (loader.get ()));
299
+
300
+ auto module1 = std::make_unique<Module>(std::move (data_loader));
301
+
302
+ auto load_error = module1->load ();
303
+ EXPECT_EQ (load_error, Error::Ok);
304
+ EXPECT_TRUE (module1->is_loaded ());
305
+
306
+ std::array<float , 2 > input{1 , 2 };
307
+ std::array<int32_t , 2 > sizes{1 , 2 };
308
+ TensorImpl tensor (
309
+ ScalarType::Float, sizes.size (), sizes.data (), input.data ());
310
+
311
+ auto result1 = module1->execute (" forward" , {EValue (Tensor (&tensor))});
312
+ EXPECT_TRUE (result1.ok ());
313
+
314
+ auto module2 = std::make_unique<Module>(module1->program ());
315
+
316
+ auto result2 = module2->execute (" forward" , {EValue (Tensor (&tensor))});
317
+ EXPECT_TRUE (result2.ok ());
318
+
319
+ module1 = std::make_unique<Module>(" /path/to/nonexistent/file.pte" );
320
+ EXPECT_FALSE (module1->is_loaded ());
321
+
322
+ auto result3 = module2->execute (" forward" , {EValue (Tensor (&tensor))});
323
+ EXPECT_TRUE (result3.ok ());
324
+ }
325
+
326
+ TEST_F (ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
327
+ std::shared_ptr<Program> shared_program;
328
+
329
+ {
330
+ auto loader = util::FileDataLoader::from (model_path_.c_str ());
331
+ EXPECT_TRUE (loader.ok ());
332
+ auto data_loader =
333
+ std::make_unique<util::FileDataLoader>(std::move (loader.get ()));
334
+ auto * data_loader_ptr = data_loader.get ();
335
+
336
+ Module module (std::move (data_loader));
337
+
338
+ auto load_error = module.load ();
339
+ EXPECT_EQ (load_error, Error::Ok);
340
+ EXPECT_TRUE (module.is_loaded ());
341
+
342
+ shared_program = module.program ();
343
+ EXPECT_NE (shared_program, nullptr );
344
+
345
+ EXPECT_NE (data_loader_ptr, nullptr );
346
+ }
347
+
348
+ EXPECT_NE (shared_program, nullptr );
349
+
350
+ Module module (shared_program);
351
+
352
+ EXPECT_EQ (module.program (), shared_program);
353
+
354
+ std::array<float , 2 > input{1 , 2 };
355
+ std::array<int32_t , 2 > sizes{1 , 2 };
356
+ TensorImpl tensor (
357
+ ScalarType::Float, sizes.size (), sizes.data (), input.data ());
358
+
359
+ auto result = module.execute (" forward" , {EValue (Tensor (&tensor))});
360
+ EXPECT_TRUE (result.ok ());
361
+
362
+ auto data = result->at (0 ).toTensor ().const_data_ptr <float >();
363
+
364
+ EXPECT_NEAR (data[0 ], 1.5 , 1e-5 );
365
+ }
366
+
367
+ TEST_F (ModuleTest, TestConcurrentExecutionWithSharedProgram) {
368
+ std::shared_ptr<Program> program;
369
+ {
370
+ Module module (model_path_);
371
+ EXPECT_FALSE (module.is_loaded ());
372
+
373
+ auto load_error = module.load ();
374
+ EXPECT_EQ (load_error, Error::Ok);
375
+ EXPECT_TRUE (module.is_loaded ());
376
+
377
+ program = module.program ();
378
+ }
379
+ EXPECT_TRUE (program != nullptr );
380
+
381
+ auto thread = [](std::shared_ptr<Program> program,
382
+ const std::array<float , 2 >& input) {
383
+ Module module (program);
384
+ std::array<int32_t , 2 > sizes{1 , 2 };
385
+ TensorImpl tensor (
386
+ ScalarType::Float, sizes.size (), sizes.data (), (void *)input.data ());
387
+
388
+ const auto result = module.forward ({EValue (Tensor (&tensor))});
389
+ EXPECT_TRUE (result.ok ());
390
+
391
+ const auto data = result->at (0 ).toTensor ().const_data_ptr <float >();
392
+ EXPECT_NEAR (data[0 ], (input[0 ] + input[1 ]) / 2.0 , 1e-5 );
393
+ };
394
+
395
+ std::thread t1 (thread, program, std::array<float , 2 >{1 , 2 });
396
+ std::thread t2 (thread, program, std::array<float , 2 >{2 , 3 });
397
+ std::thread t3 (thread, program, std::array<float , 2 >{3 , 4 });
398
+ std::thread t4 (thread, program, std::array<float , 2 >{4 , 5 });
399
+ std::thread t5 (thread, program, std::array<float , 2 >{5 , 6 });
400
+
401
+ t1.join ();
402
+ t2.join ();
403
+ t3.join ();
404
+ t4.join ();
405
+ t5.join ();
406
+ }
407
+
254
408
} // namespace torch::executor
0 commit comments