Skip to content

Commit 8b948e8

Browse files
authored
Add a convenient constructor
Differential Revision: D71956172 Pull Request resolved: #9707
1 parent ec1cd04 commit 8b948e8

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

extension/llm/runner/stats.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ struct ET_EXPERIMENTAL Stats {
8282
long aggregate_sampling_timer_start_timestamp = 0;
8383
};
8484

85-
static constexpr auto kTopp = 0.9f;
86-
8785
inline std::string stats_to_json_string(const Stats& stats) {
8886
std::stringstream ss;
8987
ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
@@ -168,7 +166,6 @@ namespace executorch {
168166
namespace llm {
169167
// TODO(T197294990): Remove these deprecated aliases once all users have moved
170168
// to the new `::executorch` namespaces.
171-
using ::executorch::extension::llm::kTopp;
172169
using ::executorch::extension::llm::print_report;
173170
using ::executorch::extension::llm::Stats;
174171
} // namespace llm

extension/llm/sampler/sampler.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
#include <executorch/extension/llm/sampler/sampler.h>
3636
#include <algorithm>
37+
#include <ctime>
3738

3839
namespace executorch {
3940
namespace extension {
@@ -129,6 +130,12 @@ Sampler::Sampler(
129130
topp_(topp),
130131
rng_state_(rng_seed) {}
131132

133+
Sampler::Sampler(int vocab_size, float temperature)
134+
: vocab_size_(vocab_size),
135+
inv_temperature_(static_cast<bool>(temperature) ? 1.0f / temperature : 0),
136+
topp_(kTopp),
137+
rng_state_(std::time(nullptr)) {}
138+
132139
template <typename T>
133140
static void softmax(T* x, int size) {
134141
// find max value (for numerical stability)

extension/llm/sampler/sampler.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ namespace extension {
2626
namespace llm {
2727
// A simple llama2 sampler.
2828

29+
inline constexpr auto kTopp = 0.9f;
30+
2931
template <typename T>
3032
struct ET_EXPERIMENTAL ProbIndex {
3133
T prob;
@@ -40,6 +42,8 @@ class ET_EXPERIMENTAL Sampler {
4042
float topp,
4143
unsigned long long rng_seed);
4244

45+
Sampler(int32_t vocab_size, float temperature);
46+
4347
template <typename T>
4448
int32_t sample(T* logits);
4549

@@ -71,3 +75,9 @@ using ::executorch::extension::llm::ProbIndex;
7175
using ::executorch::extension::llm::Sampler;
7276
} // namespace executor
7377
} // namespace torch
78+
79+
namespace executorch::llm {
80+
// TODO(T197294990): Remove these deprecated aliases once all users have moved
81+
// to the new `::executorch::extension::llm` namespaces.
82+
using ::executorch::extension::llm::kTopp;
83+
} // namespace executorch::llm

0 commit comments

Comments
 (0)