Skip to content

Commit 6cb5726

Browse files
authored
Allow sharing Program among several Modules.
Differential Revision: D61370459 Pull Request resolved: #4749
1 parent 45e9f6b commit 6cb5726

File tree

4 files changed

+222
-19
lines changed

4 files changed

+222
-19
lines changed

extension/module/module.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,22 @@ Module::Module(
6363
runtime_init();
6464
}
6565

66+
Module::Module(
67+
std::shared_ptr<Program> program,
68+
std::unique_ptr<MemoryAllocator> memory_allocator,
69+
std::unique_ptr<MemoryAllocator> temp_allocator,
70+
std::unique_ptr<EventTracer> event_tracer)
71+
: program_(std::move(program)),
72+
memory_allocator_(
73+
memory_allocator ? std::move(memory_allocator)
74+
: std::make_unique<util::MallocMemoryAllocator>()),
75+
temp_allocator_(
76+
temp_allocator ? std::move(temp_allocator)
77+
: std::make_unique<util::MallocMemoryAllocator>()),
78+
event_tracer_(std::move(event_tracer)) {
79+
runtime_init();
80+
}
81+
6682
Error Module::load(const Program::Verification verification) {
6783
if (!is_loaded()) {
6884
if (!data_loader_) {
@@ -86,8 +102,13 @@ Error Module::load(const Program::Verification verification) {
86102
break;
87103
}
88104
};
89-
program_ =
105+
auto program =
90106
ET_UNWRAP_UNIQUE(Program::load(data_loader_.get(), verification));
107+
program_ = std::shared_ptr<Program>(
108+
program.release(),
109+
[data_loader = std::move(data_loader_)](Program* pointer) {
110+
delete pointer;
111+
});
91112
}
92113
return Error::Ok;
93114
}

extension/module/module.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,30 @@ class Module final {
6363
std::unique_ptr<MemoryAllocator> memory_allocator = nullptr,
6464
std::unique_ptr<MemoryAllocator> temp_allocator = nullptr,
6565
std::unique_ptr<EventTracer> event_tracer = nullptr);
66+
67+
/**
68+
* Constructs an instance using an existing shared program.
69+
*
70+
* @param[in] program The shared program to use. It's required the data loader
71+
* the program uses is valid for the lifetime of the program.
72+
* @param[in] memory_allocator A MemoryAllocator used for memory management.
73+
* @param[in] temp_allocator A MemoryAllocator to use when allocating
74+
* temporary data.
75+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
76+
*/
77+
explicit Module(
78+
std::shared_ptr<Program> program,
79+
std::unique_ptr<MemoryAllocator> memory_allocator = nullptr,
80+
std::unique_ptr<MemoryAllocator> temp_allocator = nullptr,
81+
std::unique_ptr<EventTracer> event_tracer = nullptr);
82+
6683
Module(const Module&) = delete;
6784
Module& operator=(const Module&) = delete;
6885
Module(Module&&) = delete;
6986
Module& operator=(Module&&) = delete;
7087

7188
/**
72-
* Loads the program using the specified data loader and memory allocator.
89+
* Loads the program if needed.
7390
*
7491
* @param[in] verification The type of verification to do before returning
7592
* success.
@@ -88,6 +105,16 @@ class Module final {
88105
*/
89106
bool is_loaded() const;
90107

108+
/**
109+
* Get the program. The data loader used by the program is guaranteed to be
110+
* valid for the lifetime of the program.
111+
*
112+
* @returns Shared pointer to the program or nullptr if it's not yet loaded.
113+
*/
114+
std::shared_ptr<Program> program() const {
115+
return program_;
116+
}
117+
91118
/**
92119
* Get a list of method names available in the loaded program.
93120
* Loads the program and method if needed.
@@ -253,11 +280,11 @@ class Module final {
253280
private:
254281
std::string file_path_;
255282
LoadMode load_mode_{LoadMode::MmapUseMlock};
283+
std::shared_ptr<Program> program_;
256284
std::unique_ptr<DataLoader> data_loader_;
257285
std::unique_ptr<MemoryAllocator> memory_allocator_;
258286
std::unique_ptr<MemoryAllocator> temp_allocator_;
259287
std::unique_ptr<EventTracer> event_tracer_;
260-
std::unique_ptr<Program> program_;
261288
std::unordered_map<std::string, MethodHolder> methods_;
262289
};
263290

extension/module/test/module_test.cpp

Lines changed: 170 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,32 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <gtest/gtest.h>
10-
119
#include <executorch/extension/module/module.h>
10+
1211
#include <array>
12+
#include <thread>
13+
14+
#include <gtest/gtest.h>
15+
16+
#include <executorch/extension/data_loader/file_data_loader.h>
1317

1418
using namespace ::testing;
1519

1620
namespace torch::executor {
1721

18-
class ModuleTest : public ::testing::Test {};
22+
class ModuleTest : public ::testing::Test {
23+
protected:
24+
static void SetUpTestSuite() {
25+
model_path_ = std::getenv("RESOURCES_PATH") + std::string("/model.pte");
26+
}
27+
28+
static std::string model_path_;
29+
};
30+
31+
std::string ModuleTest::model_path_;
1932

2033
TEST_F(ModuleTest, TestLoad) {
21-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
34+
Module module(model_path_);
2235

2336
EXPECT_FALSE(module.is_loaded());
2437
const auto error = module.load();
@@ -43,7 +56,7 @@ TEST_F(ModuleTest, TestLoadCorruptedFile) {
4356
}
4457

4558
TEST_F(ModuleTest, TestMethodNames) {
46-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
59+
Module module(model_path_);
4760

4861
const auto method_names = module.method_names();
4962
EXPECT_TRUE(method_names.ok());
@@ -58,7 +71,7 @@ TEST_F(ModuleTest, TestNonExistentMethodNames) {
5871
}
5972

6073
TEST_F(ModuleTest, TestLoadMethod) {
61-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
74+
Module module(model_path_);
6275

6376
EXPECT_FALSE(module.is_method_loaded("forward"));
6477
const auto error = module.load_method("forward");
@@ -68,7 +81,7 @@ TEST_F(ModuleTest, TestLoadMethod) {
6881
}
6982

7083
TEST_F(ModuleTest, TestLoadNonExistentMethod) {
71-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
84+
Module module(model_path_);
7285

7386
const auto error = module.load_method("backward");
7487
EXPECT_NE(error, Error::Ok);
@@ -77,7 +90,7 @@ TEST_F(ModuleTest, TestLoadNonExistentMethod) {
7790
}
7891

7992
TEST_F(ModuleTest, TestMethodMeta) {
80-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
93+
Module module(model_path_);
8194

8295
const auto meta = module.method_meta("forward");
8396
EXPECT_TRUE(meta.ok());
@@ -109,7 +122,7 @@ TEST_F(ModuleTest, TestNonExistentMethodMeta) {
109122
}
110123

111124
TEST_F(ModuleTest, TestExecute) {
112-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
125+
Module module(model_path_);
113126

114127
std::array<float, 2> input{1, 2};
115128
std::array<int32_t, 2> sizes{1, 2};
@@ -127,7 +140,7 @@ TEST_F(ModuleTest, TestExecute) {
127140
}
128141

129142
TEST_F(ModuleTest, TestExecutePreload) {
130-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
143+
Module module(model_path_);
131144

132145
const auto error = module.load();
133146
EXPECT_EQ(error, Error::Ok);
@@ -146,7 +159,7 @@ TEST_F(ModuleTest, TestExecutePreload) {
146159
}
147160

148161
TEST_F(ModuleTest, TestExecutePreload_method) {
149-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
162+
Module module(model_path_);
150163

151164
const auto error = module.load_method("forward");
152165
EXPECT_EQ(error, Error::Ok);
@@ -165,7 +178,7 @@ TEST_F(ModuleTest, TestExecutePreload_method) {
165178
}
166179

167180
TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) {
168-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
181+
Module module(model_path_);
169182

170183
const auto load_error = module.load();
171184
EXPECT_EQ(load_error, Error::Ok);
@@ -203,7 +216,7 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) {
203216
}
204217

205218
TEST_F(ModuleTest, TestGet) {
206-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
219+
Module module(model_path_);
207220

208221
std::array<float, 2> input{1, 2};
209222
std::array<int32_t, 2> sizes{1, 2};
@@ -218,8 +231,7 @@ TEST_F(ModuleTest, TestGet) {
218231
}
219232

220233
TEST_F(ModuleTest, TestForward) {
221-
auto module = std::make_unique<Module>(
222-
std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
234+
auto module = std::make_unique<Module>(model_path_);
223235

224236
std::array<float, 2> input{1, 2};
225237
std::array<int32_t, 2> sizes{1, 2};
@@ -244,11 +256,153 @@ TEST_F(ModuleTest, TestForward) {
244256
}
245257

246258
TEST_F(ModuleTest, TestForwardWithInvalidInputs) {
247-
Module module(std::getenv("RESOURCES_PATH") + std::string("/model.pte"));
259+
Module module(model_path_);
248260

249261
const auto result = module.forward({EValue()});
250262

251263
EXPECT_FALSE(result.ok());
252264
}
253265

266+
TEST_F(ModuleTest, TestProgramSharingBetweenModules) {
267+
Module module1(model_path_);
268+
EXPECT_FALSE(module1.is_loaded());
269+
270+
auto load_error = module1.load();
271+
EXPECT_EQ(load_error, Error::Ok);
272+
EXPECT_TRUE(module1.is_loaded());
273+
274+
Module module2(module1.program());
275+
EXPECT_TRUE(module2.is_loaded());
276+
277+
auto method_names1 = module1.method_names();
278+
EXPECT_TRUE(method_names1.ok());
279+
280+
auto method_names2 = module2.method_names();
281+
EXPECT_TRUE(method_names2.ok());
282+
EXPECT_EQ(method_names1.get(), method_names2.get());
283+
284+
auto load_method_error = module1.load_method("forward");
285+
EXPECT_EQ(load_method_error, Error::Ok);
286+
EXPECT_TRUE(module1.is_method_loaded("forward"));
287+
EXPECT_FALSE(module2.is_method_loaded("forward"));
288+
289+
auto load_method_error2 = module2.load_method("forward");
290+
EXPECT_EQ(load_method_error2, Error::Ok);
291+
EXPECT_TRUE(module2.is_method_loaded("forward"));
292+
}
293+
294+
TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) {
295+
auto loader = util::FileDataLoader::from(model_path_.c_str());
296+
EXPECT_TRUE(loader.ok());
297+
auto data_loader =
298+
std::make_unique<util::FileDataLoader>(std::move(loader.get()));
299+
300+
auto module1 = std::make_unique<Module>(std::move(data_loader));
301+
302+
auto load_error = module1->load();
303+
EXPECT_EQ(load_error, Error::Ok);
304+
EXPECT_TRUE(module1->is_loaded());
305+
306+
std::array<float, 2> input{1, 2};
307+
std::array<int32_t, 2> sizes{1, 2};
308+
TensorImpl tensor(
309+
ScalarType::Float, sizes.size(), sizes.data(), input.data());
310+
311+
auto result1 = module1->execute("forward", {EValue(Tensor(&tensor))});
312+
EXPECT_TRUE(result1.ok());
313+
314+
auto module2 = std::make_unique<Module>(module1->program());
315+
316+
auto result2 = module2->execute("forward", {EValue(Tensor(&tensor))});
317+
EXPECT_TRUE(result2.ok());
318+
319+
module1 = std::make_unique<Module>("/path/to/nonexistent/file.pte");
320+
EXPECT_FALSE(module1->is_loaded());
321+
322+
auto result3 = module2->execute("forward", {EValue(Tensor(&tensor))});
323+
EXPECT_TRUE(result3.ok());
324+
}
325+
326+
TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) {
327+
std::shared_ptr<Program> shared_program;
328+
329+
{
330+
auto loader = util::FileDataLoader::from(model_path_.c_str());
331+
EXPECT_TRUE(loader.ok());
332+
auto data_loader =
333+
std::make_unique<util::FileDataLoader>(std::move(loader.get()));
334+
auto* data_loader_ptr = data_loader.get();
335+
336+
Module module(std::move(data_loader));
337+
338+
auto load_error = module.load();
339+
EXPECT_EQ(load_error, Error::Ok);
340+
EXPECT_TRUE(module.is_loaded());
341+
342+
shared_program = module.program();
343+
EXPECT_NE(shared_program, nullptr);
344+
345+
EXPECT_NE(data_loader_ptr, nullptr);
346+
}
347+
348+
EXPECT_NE(shared_program, nullptr);
349+
350+
Module module(shared_program);
351+
352+
EXPECT_EQ(module.program(), shared_program);
353+
354+
std::array<float, 2> input{1, 2};
355+
std::array<int32_t, 2> sizes{1, 2};
356+
TensorImpl tensor(
357+
ScalarType::Float, sizes.size(), sizes.data(), input.data());
358+
359+
auto result = module.execute("forward", {EValue(Tensor(&tensor))});
360+
EXPECT_TRUE(result.ok());
361+
362+
auto data = result->at(0).toTensor().const_data_ptr<float>();
363+
364+
EXPECT_NEAR(data[0], 1.5, 1e-5);
365+
}
366+
367+
TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
368+
std::shared_ptr<Program> program;
369+
{
370+
Module module(model_path_);
371+
EXPECT_FALSE(module.is_loaded());
372+
373+
auto load_error = module.load();
374+
EXPECT_EQ(load_error, Error::Ok);
375+
EXPECT_TRUE(module.is_loaded());
376+
377+
program = module.program();
378+
}
379+
EXPECT_TRUE(program != nullptr);
380+
381+
auto thread = [](std::shared_ptr<Program> program,
382+
const std::array<float, 2>& input) {
383+
Module module(program);
384+
std::array<int32_t, 2> sizes{1, 2};
385+
TensorImpl tensor(
386+
ScalarType::Float, sizes.size(), sizes.data(), (void*)input.data());
387+
388+
const auto result = module.forward({EValue(Tensor(&tensor))});
389+
EXPECT_TRUE(result.ok());
390+
391+
const auto data = result->at(0).toTensor().const_data_ptr<float>();
392+
EXPECT_NEAR(data[0], (input[0] + input[1]) / 2.0, 1e-5);
393+
};
394+
395+
std::thread t1(thread, program, std::array<float, 2>{1, 2});
396+
std::thread t2(thread, program, std::array<float, 2>{2, 3});
397+
std::thread t3(thread, program, std::array<float, 2>{3, 4});
398+
std::thread t4(thread, program, std::array<float, 2>{4, 5});
399+
std::thread t5(thread, program, std::array<float, 2>{5, 6});
400+
401+
t1.join();
402+
t2.join();
403+
t3.join();
404+
t4.join();
405+
t5.join();
406+
}
407+
254408
} // namespace torch::executor

extension/module/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def define_common_targets():
1414
],
1515
deps = [
1616
"//executorch/kernels/portable:generated_lib",
17+
"//executorch/extension/data_loader:file_data_loader",
1718
"//executorch/extension/module:module",
1819
],
1920
env = {

0 commit comments

Comments
 (0)