File tree Expand file tree Collapse file tree 3 files changed +17
-3
lines changed Expand file tree Collapse file tree 3 files changed +17
-3
lines changed Original file line number Diff line number Diff line change @@ -82,8 +82,6 @@ struct ET_EXPERIMENTAL Stats {
82
82
long aggregate_sampling_timer_start_timestamp = 0 ;
83
83
};
84
84
85
- static constexpr auto kTopp = 0 .9f ;
86
-
87
85
inline std::string stats_to_json_string (const Stats& stats) {
88
86
std::stringstream ss;
89
87
ss << " {\" prompt_tokens\" :" << stats.num_prompt_tokens << " ,"
@@ -168,7 +166,6 @@ namespace executorch {
168
166
namespace llm {
169
167
// TODO(T197294990): Remove these deprecated aliases once all users have moved
170
168
// to the new `::executorch` namespaces.
171
- using ::executorch::extension::llm::kTopp ;
172
169
using ::executorch::extension::llm::print_report;
173
170
using ::executorch::extension::llm::Stats;
174
171
} // namespace llm
Original file line number Diff line number Diff line change 34
34
35
35
#include < executorch/extension/llm/sampler/sampler.h>
36
36
#include < algorithm>
37
+ #include < ctime>
37
38
38
39
namespace executorch {
39
40
namespace extension {
@@ -129,6 +130,12 @@ Sampler::Sampler(
129
130
topp_(topp),
130
131
rng_state_(rng_seed) {}
131
132
133
+ Sampler::Sampler (int vocab_size, float temperature)
134
+ : vocab_size_(vocab_size),
135
+ inv_temperature_(static_cast <bool >(temperature) ? 1.0f / temperature : 0),
136
+ topp_(kTopp ),
137
+ rng_state_(std::time(nullptr )) {}
138
+
132
139
template <typename T>
133
140
static void softmax (T* x, int size) {
134
141
// find max value (for numerical stability)
Original file line number Diff line number Diff line change @@ -26,6 +26,8 @@ namespace extension {
26
26
namespace llm {
27
27
// A simple llama2 sampler.
28
28
29
+ inline constexpr auto kTopp = 0 .9f ;
30
+
29
31
template <typename T>
30
32
struct ET_EXPERIMENTAL ProbIndex {
31
33
T prob;
@@ -40,6 +42,8 @@ class ET_EXPERIMENTAL Sampler {
40
42
float topp,
41
43
unsigned long long rng_seed);
42
44
45
+ Sampler (int32_t vocab_size, float temperature);
46
+
43
47
template <typename T>
44
48
int32_t sample (T* logits);
45
49
@@ -71,3 +75,9 @@ using ::executorch::extension::llm::ProbIndex;
71
75
using ::executorch::extension::llm::Sampler;
72
76
} // namespace executor
73
77
} // namespace torch
78
+
79
+ namespace executorch ::llm {
80
+ // TODO(T197294990): Remove these deprecated aliases once all users have moved
81
+ // to the new `::executorch::extension::llm` namespaces.
82
+ using ::executorch::extension::llm::kTopp ;
83
+ } // namespace executorch::llm
You can’t perform that action at this time.
0 commit comments