Skip to content

Commit 95f779a

Browse files
pytorchbot and lucylq authored
1 parent f92ad9d commit 95f779a

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

examples/models/llama/runner/runner.cpp

+8 -2
Original file line number | Diff line number | Diff line change
@@ -39,19 +39,25 @@ static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
3939
Runner::Runner(
4040
const std::string& model_path,
4141
const std::string& tokenizer_path,
42-
const float temperature)
42+
const float temperature,
43+
std::optional<const std::string> data_path)
4344
// NOTE: we observed ~2x loading performance increase on iPhone 15
4445
// and a ~5% improvement on Galaxy S22 by switching to
4546
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
4647
: temperature_(temperature),
47-
module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
4848
tokenizer_path_(tokenizer_path),
4949
metadata_({
5050
{kEnableDynamicShape, false},
5151
{kMaxSeqLen, 128},
5252
{kUseKVCache, true},
5353
{kUseSDPAWithKVCache, false},
5454
}) {
55+
if (data_path.has_value()) {
56+
module_ = std::make_unique<Module>(
57+
model_path, data_path.value(), Module::LoadMode::File);
58+
} else {
59+
module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
60+
}
5561
ET_LOG(
5662
Info,
5763
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",

examples/models/llama/runner/runner.h

+3 -1
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
#include <cstdint>
1515
#include <functional>
1616
#include <memory>
17+
#include <optional>
1718
#include <string>
1819
#include <unordered_map>
1920

@@ -32,7 +33,8 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
3233
explicit Runner(
3334
const std::string& model_path,
3435
const std::string& tokenizer_path,
35-
const float temperature = 0.8f);
36+
const float temperature = 0.8f,
37+
std::optional<const std::string> data_path = std::nullopt);
3638

3739
bool is_loaded() const;
3840
::executorch::runtime::Error load();

0 commit comments

Comments
 (0)