Skip to content

Commit 400cebb

Browse files
cccclaifacebook-github-bot
authored andcommitted
Expose method name as part of backend init context (#6622)
Summary: Provide the method name to backend so they can load the corresponding method name accordingly. The most immediate need is that the qnn context binary can include two methods, one for prefill and one for decode. Since we don't allow backend access multi methods at the moment, we do it in a hacky way via following ## AOT: ``` class LLama_transformer(): def prefill() def decode() ``` Then we will have two custom ops from two to_backends ops, and both will have two context binary ``` QAT (prefill) -> to_backend(...) => prefill.qcir flatbuffers QAT (decode) -> to_backend(...) => decode.qcir flatbuffers => graph prefill( custom_op_prefill() -> context_binary (two graphs) ) graph decode() custom_op_decode() -> context_binary (two graphs) ) ``` Since two context binary from these two customs ops will be exactly the same and they can be deduplicate during emit via these two lines https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L136 and here https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L1065-L1066 ``` .pte instrucions [ "prefill" [instructions: call_delegate(prefill_input)] "decode": [instructions: call_delegate(decode_input)] "delegate_payload:: Dict[bytes, index]) ] ``` ## Runtime After we expose the method name via this change, the backend can access the method name, and load the same method as the top level method ``` Result<DelegateHandle*> QNNBackend::init( BackendInitContext& context, FreeableBuffer* processed, ArrayRef<CompileSpec> compile_specs) { const char* method_name = context.get_method_name() // for example, "prefill" handle = qnn_backend.load(method_name) return handle } ``` This is to unblock sharing weight between prefill and decode for using htp backend. Reviewed By: dbort Differential Revision: D65386597
1 parent e332e2a commit 400cebb

File tree

4 files changed

+77
-8
lines changed

4 files changed

+77
-8
lines changed

runtime/backend/backend_execution_context.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ class BackendExecutionContext final {
2121
public:
2222
BackendExecutionContext(
2323
EventTracer* event_tracer = nullptr,
24-
MemoryAllocator* temp_allocator = nullptr)
25-
: event_tracer_(event_tracer), temp_allocator_(temp_allocator) {}
24+
MemoryAllocator* temp_allocator = nullptr,
25+
const char* method_name = nullptr)
26+
: event_tracer_(event_tracer),
27+
temp_allocator_(temp_allocator),
28+
method_name_(method_name) {}
2629

2730
/**
2831
* Returns a pointer to an instance of EventTracer to do profiling/debugging
@@ -52,9 +55,17 @@ class BackendExecutionContext final {
5255
return temp_allocator_;
5356
}
5457

58+
/**
59+
* Get the name of the executing method from the ExecuTorch runtime.
60+
*/
61+
const char* get_method_name() const {
62+
return method_name_;
63+
}
64+
5565
private:
5666
EventTracer* event_tracer_ = nullptr;
5767
MemoryAllocator* temp_allocator_ = nullptr;
68+
const char* method_name_ = nullptr;
5869
};
5970

6071
} // namespace runtime

runtime/backend/backend_init_context.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ namespace runtime {
1818
*/
1919
class BackendInitContext final {
2020
public:
21-
explicit BackendInitContext(MemoryAllocator* runtime_allocator)
22-
: runtime_allocator_(runtime_allocator) {}
21+
explicit BackendInitContext(
22+
MemoryAllocator* runtime_allocator,
23+
const char* method_name = nullptr)
24+
: runtime_allocator_(runtime_allocator), method_name_(method_name) {}
2325

2426
/** Get the runtime allocator passed from Method. It's the same runtime
2527
* executor used by the standard executor runtime and the life span is the
@@ -29,8 +31,20 @@ class BackendInitContext final {
2931
return runtime_allocator_;
3032
}
3133

34+
/** Get the loaded method name from ExecuTorch runtime. Usually it's
35+
* "forward", however, if there are multiple methods in the .pte file, it can
36+
* be different. One example is that we may have prefill and decode methods in
37+
* the same .pte file. In this case, when client loads "prefill" method, the
38+
* `get_method_name` function will return "prefill", when client loads
39+
* "decode" method, the `get_method_name` function will return "decode".
40+
*/
41+
const char* get_method_name() const {
42+
return method_name_;
43+
}
44+
3245
private:
3346
MemoryAllocator* runtime_allocator_ = nullptr;
47+
const char* method_name_ = nullptr;
3448
};
3549

3650
} // namespace runtime

runtime/executor/method.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,9 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
626626

627627
for (size_t i = 0; i < n_delegate; ++i) {
628628
const auto& delegate = *delegates->Get(i);
629-
BackendInitContext backend_init_context(method_allocator);
629+
BackendInitContext backend_init_context(
630+
method_allocator,
631+
/*method_name=*/serialization_plan_->name()->c_str());
630632
Error err = BackendDelegate::Init(
631633
delegate, program_, backend_init_context, &delegates_[i]);
632634
if (err != Error::Ok) {
@@ -1097,8 +1099,9 @@ Error Method::execute_instruction() {
10971099
n_delegate_,
10981100
step_state_.instr_idx);
10991101
BackendExecutionContext backend_execution_context(
1100-
/*event_tracer*/ event_tracer_,
1101-
/*temp_allocator*/ temp_allocator_);
1102+
/*event_tracer=*/event_tracer_,
1103+
/*temp_allocator=*/temp_allocator_,
1104+
/*method_name=*/serialization_plan_->name()->c_str());
11021105
err = delegates_[delegate_idx].Execute(
11031106
backend_execution_context,
11041107
chain.argument_lists_[step_state_.instr_idx].data());

runtime/executor/test/backend_integration_test.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class StubBackend final : public BackendInterface {
9595
}
9696

9797
Error execute(
98-
ET_UNUSED BackendExecutionContext& context,
98+
BackendExecutionContext& context,
9999
DelegateHandle* handle,
100100
EValue** args) const override {
101101
if (execute_fn_) {
@@ -530,6 +530,47 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
530530
EXPECT_EQ(backend_load_was_called, using_segments());
531531
}
532532

533+
TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) {
534+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
535+
ASSERT_EQ(loader.error(), Error::Ok);
536+
const void* processed_data = nullptr;
537+
StubBackend::singleton().install_init(
538+
[&](FreeableBuffer* processed,
539+
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
540+
ET_UNUSED BackendInitContext& backend_init_context)
541+
-> Result<DelegateHandle*> {
542+
auto method_name = backend_init_context.get_method_name();
543+
// Ensure that we can get the method name during init via context
544+
EXPECT_STREQ(method_name, "forward");
545+
processed_data = processed->data();
546+
return nullptr;
547+
});
548+
Result<Program> program = Program::load(&loader.get());
549+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
550+
Result<Method> method = program->load_method("forward", &mmm.get());
551+
EXPECT_TRUE(method.ok());
552+
ASSERT_EQ(program.error(), Error::Ok);
553+
}
554+
555+
TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) {
556+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
557+
ASSERT_EQ(loader.error(), Error::Ok);
558+
StubBackend::singleton().install_execute(
559+
[&](BackendExecutionContext& backend_execution_context, ET_UNUSED DelegateHandle* handle, ET_UNUSED EValue** args)-> Error {
560+
// Ensure that we can get the method name during execution via context
561+
auto method_name = backend_execution_context.get_method_name();
562+
EXPECT_STREQ(method_name, "forward");
563+
return Error::Ok;
564+
});
565+
Result<Program> program = Program::load(&loader.get());
566+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
567+
Result<Method> method = program->load_method("forward", &mmm.get());
568+
EXPECT_TRUE(method.ok());
569+
Error err = method->execute();
570+
ASSERT_EQ(err, Error::Ok);
571+
572+
}
573+
533574
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
534575
// - Errors during init() or execute() result in runtime init/execution failures
535576
// - Correct values are passed to init()/execute()

0 commit comments

Comments
 (0)