Skip to content

Commit 2205c9f

Browse files
cccclai and facebook-github-bot
authored and committed
Expose method name as part of backend init context (#6622)
Summary: Provide the method name to the backend so that it can load the corresponding method. The most immediate need is that the QNN context binary can include two methods, one for prefill and one for decode. Since we don't allow a backend to access multiple methods at the moment, we do it in a hacky way, as follows. ## AOT: ``` class LLama_transformer(): def prefill() def decode() ``` Then we will have two custom ops from two to_backend ops, and each will hold a context binary containing two graphs ``` QAT (prefill) -> to_backend(...) => prefill.qcir flatbuffers QAT (decode) -> to_backend(...) => decode.qcir flatbuffers => graph prefill( custom_op_prefill() -> context_binary (two graphs) ) graph decode( custom_op_decode() -> context_binary (two graphs) ) ``` Since the two context binaries from these two custom ops will be exactly the same, they can be deduplicated during emit via these two lines https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L136 and here https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L1065-L1066 ``` .pte instructions [ "prefill": [instructions: call_delegate(prefill_input)] "decode": [instructions: call_delegate(decode_input)] "delegate_payload": Dict[bytes, index] ] ``` ## Runtime After we expose the method name via this change, the backend can access the method name and load the same method as the top-level method ``` Result<DelegateHandle*> QNNBackend::init( BackendInitContext& context, FreeableBuffer* processed, ArrayRef<CompileSpec> compile_specs) { const char* method_name = context.get_method_name(); /* for example, "prefill" */ handle = qnn_backend.load(method_name); return handle; } ``` This is to unblock sharing weights between prefill and decode when using the HTP backend. Differential Revision: D65386597
1 parent e332e2a commit 2205c9f

File tree

4 files changed

+72
-9
lines changed

4 files changed

+72
-9
lines changed

runtime/backend/backend_execution_context.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ class BackendExecutionContext final {
2121
public:
2222
BackendExecutionContext(
2323
EventTracer* event_tracer = nullptr,
24-
MemoryAllocator* temp_allocator = nullptr)
25-
: event_tracer_(event_tracer), temp_allocator_(temp_allocator) {}
24+
MemoryAllocator* temp_allocator = nullptr,
25+
const char* method_name = nullptr)
26+
: event_tracer_(event_tracer), temp_allocator_(temp_allocator), method_name_(method_name) {}
2627

2728
/**
2829
* Returns a pointer to an instance of EventTracer to do profiling/debugging
@@ -52,9 +53,17 @@ class BackendExecutionContext final {
5253
return temp_allocator_;
5354
}
5455

56+
/**
57+
* Get the name of the executing method from the ExecuTorch runtime.
58+
*/
59+
const char* get_method_name() const {
60+
return method_name_;
61+
}
62+
5563
private:
5664
EventTracer* event_tracer_ = nullptr;
5765
MemoryAllocator* temp_allocator_ = nullptr;
66+
const char* method_name_ = nullptr;
5867
};
5968

6069
} // namespace runtime

runtime/backend/backend_init_context.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace runtime {
1818
*/
1919
class BackendInitContext final {
2020
public:
21-
explicit BackendInitContext(MemoryAllocator* runtime_allocator)
22-
: runtime_allocator_(runtime_allocator) {}
21+
explicit BackendInitContext(MemoryAllocator* runtime_allocator, const char* method_name = nullptr)
22+
: runtime_allocator_(runtime_allocator), method_name_(method_name) {}
2323

2424
/** Get the runtime allocator passed from Method. It's the same runtime
2525
* executor used by the standard executor runtime and the life span is the
@@ -28,9 +28,21 @@ class BackendInitContext final {
2828
MemoryAllocator* get_runtime_allocator() {
2929
return runtime_allocator_;
3030
}
31+
32+
/** Get the loaded method name from ExecuTorch runtime. Usually it's "forward",
33+
* however, if there are multiple methods in the .pte file, it can be different.
34+
* One example is that we may have prefill and decode methods in the same .pte file.
35+
* In this case, when client loads "prefill" method, the `get_method_name` function will
36+
* return "prefill", when client loads "decode" method, the `get_method_name` function will
37+
* return "decode".
38+
*/
39+
const char* get_method_name() const {
40+
return method_name_;
41+
}
3142

3243
private:
3344
MemoryAllocator* runtime_allocator_ = nullptr;
45+
const char* method_name_ = nullptr;
3446
};
3547

3648
} // namespace runtime

runtime/executor/method.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
626626

627627
for (size_t i = 0; i < n_delegate; ++i) {
628628
const auto& delegate = *delegates->Get(i);
629-
BackendInitContext backend_init_context(method_allocator);
629+
BackendInitContext backend_init_context(method_allocator, /*method_name=*/serialization_plan_->name()->c_str());
630630
Error err = BackendDelegate::Init(
631631
delegate, program_, backend_init_context, &delegates_[i]);
632632
if (err != Error::Ok) {
@@ -1097,8 +1097,9 @@ Error Method::execute_instruction() {
10971097
n_delegate_,
10981098
step_state_.instr_idx);
10991099
BackendExecutionContext backend_execution_context(
1100-
/*event_tracer*/ event_tracer_,
1101-
/*temp_allocator*/ temp_allocator_);
1100+
/*event_tracer=*/event_tracer_,
1101+
/*temp_allocator=*/temp_allocator_,
1102+
/*method_name=*/serialization_plan_->name()->c_str());
11021103
err = delegates_[delegate_idx].Execute(
11031104
backend_execution_context,
11041105
chain.argument_lists_[step_state_.instr_idx].data());
@@ -1238,7 +1239,7 @@ Error Method::step() {
12381239
step_state_.chain_idx += 1;
12391240
return Error::Ok;
12401241
}
1241-
1242+
12421243
auto status = execute_instruction();
12431244
if (status != Error::Ok) {
12441245
return status;

runtime/executor/test/backend_integration_test.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class StubBackend final : public BackendInterface {
9595
}
9696

9797
Error execute(
98-
ET_UNUSED BackendExecutionContext& context,
98+
BackendExecutionContext& context,
9999
DelegateHandle* handle,
100100
EValue** args) const override {
101101
if (execute_fn_) {
@@ -530,6 +530,47 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
530530
EXPECT_EQ(backend_load_was_called, using_segments());
531531
}
532532

533+
/**
 * Verifies that a backend's init callback can read the name of the method
 * being loaded ("forward" in this test) via
 * BackendInitContext::get_method_name().
 */
TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) {
  Result<FileDataLoader> loader = FileDataLoader::from(program_path());
  ASSERT_EQ(loader.error(), Error::Ok);
  StubBackend::singleton().install_init(
      [](ET_UNUSED FreeableBuffer* processed,
         ET_UNUSED ArrayRef<CompileSpec> compile_specs,
         BackendInitContext& backend_init_context)
          -> Result<DelegateHandle*> {
        // Ensure that we can get the method name during init via context
        const char* method_name = backend_init_context.get_method_name();
        EXPECT_STREQ(method_name, "forward");
        return nullptr;
      });
  Result<Program> program = Program::load(&loader.get());
  // Validate the load result before `program` is dereferenced below; the
  // original only checked it after load_method() had already used it.
  ASSERT_EQ(program.error(), Error::Ok);
  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
  Result<Method> method = program->load_method("forward", &mmm.get());
  EXPECT_TRUE(method.ok());
}
554+
555+
/**
 * Verifies that a backend's execute callback can read the name of the
 * executing method ("forward" in this test) via
 * BackendExecutionContext::get_method_name().
 */
TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) {
  Result<FileDataLoader> loader = FileDataLoader::from(program_path());
  ASSERT_EQ(loader.error(), Error::Ok);
  StubBackend::singleton().install_execute(
      [&](BackendExecutionContext& backend_execution_context,
          ET_UNUSED DelegateHandle* handle,
          ET_UNUSED EValue** args) -> Error {
        // Ensure that we can get the method name during execution via context
        const char* method_name = backend_execution_context.get_method_name();
        EXPECT_STREQ(method_name, "forward");
        return Error::Ok;
      });
  Result<Program> program = Program::load(&loader.get());
  // Stop here on a failed load instead of dereferencing an error Result.
  ASSERT_EQ(program.error(), Error::Ok);
  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
  Result<Method> method = program->load_method("forward", &mmm.get());
  // Must be fatal: `method` is dereferenced on the next line.
  ASSERT_TRUE(method.ok());
  Error err = method->execute();
  ASSERT_EQ(err, Error::Ok);
}
573+
533574
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
534575
// - Errors during init() or execute() result in runtime init/execution failures
535576
// - Correct values are passed to init()/execute()

0 commit comments

Comments
 (0)