Skip to content

Commit af3e2be

Browse files
cccclaifacebook-github-bot
authored andcommitted
Expose method name as part of backend init context (#6622)
Summary: Provide the method name to the backend so that it can load the corresponding method. The most immediate need is that the QNN context binary can include two methods, one for prefill and one for decode. Since we don't currently allow a backend to access multiple methods, we do it in a hacky way, as follows. ## AOT: ``` class LLama_transformer(): def prefill() def decode() ``` Then we will have two custom ops from two to_backend calls, and each will have a context binary containing two graphs ``` QAT (prefill) -> to_backend(...) => prefill.qcir flatbuffers QAT (decode) -> to_backend(...) => decode.qcir flatbuffers => graph prefill( custom_op_prefill() -> context_binary (two graphs) ) graph decode( custom_op_decode() -> context_binary (two graphs) ) ``` Since the two context binaries from these two custom ops will be exactly the same, they can be deduplicated during emit via these two lines https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L136 and here https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L1065-L1066 ``` .pte instructions [ "prefill" [instructions: call_delegate(prefill_input)] "decode": [instructions: call_delegate(decode_input)] "delegate_payload:: Dict[bytes, index]) ] ``` ## Runtime After we expose the method name via this change, the backend can access the method name and load the same method as the top-level method ``` Result<DelegateHandle*> QNNBackend::init( BackendInitContext& context, FreeableBuffer* processed, ArrayRef<CompileSpec> compile_specs) { const char* method_name = context.get_method_name() // for example, "prefill" handle = qnn_backend.load(method_name) return handle } ``` This is to unblock sharing weights between prefill and decode when using the HTP backend. Differential Revision: D65386597
1 parent cfcf13b commit af3e2be

File tree

5 files changed

+75
-8
lines changed

5 files changed

+75
-8
lines changed

runtime/backend/backend_execution_context.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ class BackendExecutionContext final {
2121
public:
2222
BackendExecutionContext(
2323
EventTracer* event_tracer = nullptr,
24-
MemoryAllocator* temp_allocator = nullptr)
25-
: event_tracer_(event_tracer), temp_allocator_(temp_allocator) {}
24+
MemoryAllocator* temp_allocator = nullptr,
25+
const char* method_name = nullptr)
26+
: event_tracer_(event_tracer), temp_allocator_(temp_allocator), method_name_(method_name) {}
2627

2728
/**
2829
* Returns a pointer to an instance of EventTracer to do profiling/debugging
@@ -52,9 +53,17 @@ class BackendExecutionContext final {
5253
return temp_allocator_;
5354
}
5455

56+
/**
57+
* Get the loaded method name from ExecuTorch runtime.
58+
*/
59+
const char* get_method_name() {
60+
return method_name_;
61+
}
62+
5563
private:
5664
EventTracer* event_tracer_ = nullptr;
5765
MemoryAllocator* temp_allocator_ = nullptr;
66+
const char* method_name_ = nullptr;
5867
};
5968

6069
} // namespace runtime

runtime/backend/backend_init_context.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace runtime {
1818
*/
1919
class BackendInitContext final {
2020
public:
21-
explicit BackendInitContext(MemoryAllocator* runtime_allocator)
22-
: runtime_allocator_(runtime_allocator) {}
21+
explicit BackendInitContext(MemoryAllocator* runtime_allocator, const char* method_name = nullptr)
22+
: runtime_allocator_(runtime_allocator), method_name_(method_name) {}
2323

2424
/** Get the runtime allocator passed from Method. It's the same runtime
2525
* executor used by the standard executor runtime and the life span is the
@@ -28,9 +28,21 @@ class BackendInitContext final {
2828
MemoryAllocator* get_runtime_allocator() {
2929
return runtime_allocator_;
3030
}
31+
32+
/** Get the loaded method name from ExecuTorch runtime. Usually it's "forward",
33+
* however, if there are multiple methods in the .pte file, it can be different.
34+
* One example is that we may have prefill and decode methods in the same .pte file.
35+
* In this case, when client loads "prefill" method, the `get_method_name` function will
36+
* return "prefill", when client loads "decode" method, the `get_method_name` function will
37+
* return "decode".
38+
*/
39+
const char* get_method_name() {
40+
return method_name_;
41+
}
3142

3243
private:
3344
MemoryAllocator* runtime_allocator_ = nullptr;
45+
const char* method_name_ = nullptr;
3446
};
3547

3648
} // namespace runtime

runtime/executor/method.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
589589
EXECUTORCH_SCOPE_PROF("Method::init");
590590
internal::EventTracerProfileMethodScope event_tracer_profile_scope =
591591
internal::EventTracerProfileMethodScope(event_tracer_, "Method::init");
592+
method_name_ = s_plan->name()->c_str();
592593
ET_CHECK_OR_RETURN_ERROR(
593594
// Don't use !initialized() here because we also want to fail on the
594595
// InitializationFailed state.
@@ -626,7 +627,7 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
626627

627628
for (size_t i = 0; i < n_delegate; ++i) {
628629
const auto& delegate = *delegates->Get(i);
629-
BackendInitContext backend_init_context(method_allocator);
630+
BackendInitContext backend_init_context(method_allocator, method_name_);
630631
Error err = BackendDelegate::Init(
631632
delegate, program_, backend_init_context, &delegates_[i]);
632633
if (err != Error::Ok) {
@@ -1098,7 +1099,8 @@ Error Method::execute_instruction() {
10981099
step_state_.instr_idx);
10991100
BackendExecutionContext backend_execution_context(
11001101
/*event_tracer*/ event_tracer_,
1101-
/*temp_allocator*/ temp_allocator_);
1102+
/*temp_allocator*/ temp_allocator_,
1103+
/*method_name*/method_name_);
11021104
err = delegates_[delegate_idx].Execute(
11031105
backend_execution_context,
11041106
chain.argument_lists_[step_state_.instr_idx].data());
@@ -1238,7 +1240,7 @@ Error Method::step() {
12381240
step_state_.chain_idx += 1;
12391241
return Error::Ok;
12401242
}
1241-
1243+
12421244
auto status = execute_instruction();
12431245
if (status != Error::Ok) {
12441246
return status;

runtime/executor/method.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class Method final {
6262
delegates_(rhs.delegates_),
6363
n_chains_(rhs.n_chains_),
6464
chains_(rhs.chains_),
65+
method_name_(rhs.method_name_),
6566
init_state_(rhs.init_state_) {
6667
// Required: clear out fields that the dtor looks at, so that we don't free
6768
// anything twice.
@@ -80,6 +81,7 @@ class Method final {
8081
rhs.event_tracer_ = nullptr;
8182
rhs.n_chains_ = 0;
8283
rhs.chains_ = nullptr;
84+
rhs.method_name_ = nullptr;
8385
}
8486

8587
/**
@@ -328,6 +330,7 @@ class Method final {
328330

329331
size_t n_chains_;
330332
Chain* chains_;
333+
const char* method_name_;
331334

332335
InitializationState init_state_;
333336

runtime/executor/test/backend_integration_test.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class StubBackend final : public BackendInterface {
9595
}
9696

9797
Error execute(
98-
ET_UNUSED BackendExecutionContext& context,
98+
BackendExecutionContext& context,
9999
DelegateHandle* handle,
100100
EValue** args) const override {
101101
if (execute_fn_) {
@@ -530,6 +530,47 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
530530
EXPECT_EQ(backend_load_was_called, using_segments());
531531
}
532532

533+
// Checks that a backend init callback can read the method name from
// BackendInitContext. The callback installed on the StubBackend singleton
// expects "forward", which is the method name passed to load_method() below.
TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) {
534+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
535+
ASSERT_EQ(loader.error(), Error::Ok);
536+
// Written by the init callback below; not read again within this test.
const void* processed_data = nullptr;
537+
StubBackend::singleton().install_init(
538+
[&](FreeableBuffer* processed,
539+
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
540+
// NOTE(review): tagged ET_UNUSED, yet the callback body reads this
// parameter via get_method_name().
ET_UNUSED BackendInitContext& backend_init_context)
541+
-> Result<DelegateHandle*> {
542+
auto method_name = backend_init_context.get_method_name();
543+
// Ensure that we can get the method name during init via context
544+
EXPECT_EQ(strcmp(method_name, "forward"), 0);
545+
processed_data = processed->data();
546+
return nullptr;
547+
});
548+
Result<Program> program = Program::load(&loader.get());
549+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
550+
Result<Method> method = program->load_method("forward", &mmm.get());
551+
EXPECT_TRUE(method.ok());
552+
// NOTE(review): program was already dereferenced by program->load_method()
// above, so this check comes too late to guard that call; consider asserting
// right after Program::load().
ASSERT_EQ(program.error(), Error::Ok);
553+
}
554+
555+
// Checks that a backend execute callback can read the method name from
// BackendExecutionContext. The callback installed on the StubBackend singleton
// expects "forward", which is the method name passed to load_method() below.
TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) {
556+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
557+
ASSERT_EQ(loader.error(), Error::Ok);
558+
StubBackend::singleton().install_execute(
559+
[&](BackendExecutionContext& backend_execution_context, ET_UNUSED DelegateHandle* handle, ET_UNUSED EValue** args)-> Error {
560+
// Ensure that we can get the method name during execute via context
561+
auto method_name = backend_execution_context.get_method_name();
562+
EXPECT_EQ(strcmp(method_name, "forward"), 0);
563+
return Error::Ok;
564+
});
565+
Result<Program> program = Program::load(&loader.get());
566+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
567+
Result<Method> method = program->load_method("forward", &mmm.get());
568+
// NOTE(review): this is EXPECT rather than ASSERT, so method->execute() below
// is still reached when the load did not succeed; confirm that is intended.
EXPECT_TRUE(method.ok());
569+
// NOTE(review): err is never inspected after this line, so the return value
// of execute() cannot fail the test.
Error err = method->execute();
570+
// NOTE(review): program was already dereferenced by program->load_method()
// above, so this check comes too late to guard that call; consider asserting
// right after Program::load().
ASSERT_EQ(program.error(), Error::Ok);
571+
572+
}
573+
533574
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
534575
// - Errors during init() or execute() result in runtime init/execution failures
535576
// - Correct values are passed to init()/execute()

0 commit comments

Comments
 (0)