#include "utils/json.hpp"
#include "utils/sample_log.h"
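// Post-processing / sampling utilities for LLM decoding: temperature scaling,
// repetition and diversity penalties, and top-p / top-k sampling over a logits vector.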
class LLMPostprocess {
private:
    // Temperature scaling: controls how random the output distribution is
    void apply_temperature(std::vector<float> &logits, float temperature)
    {
        if (temperature == 0.0f) temperature = 0.01f;
        for (float &logit : logits) {
            logit /= temperature;
        }
    }

    // Repetition penalty: penalizes tokens that already appear in the generation history
    void apply_repetition_penalty(std::vector<float> &logits, const std::vector<int> &history, float penalty)
    {
        for (int token : history) {
            if (token >= 0 && token < (int)logits.size()) {
                logits[token] = logits[token] < 0 ? logits[token] * penalty : logits[token] / penalty;
            }
        }
    }

    void apply_repetition_penalty(std::vector<float> &logits, const std::vector<int> &generated_tokens,
                                  float repetition_penalty, int penalty_window)
    {
        if (repetition_penalty == 1.0f || generated_tokens.empty()) {
            return;  // nothing to do when penalty == 1.0 or no tokens have been generated yet
        }

        int start_idx = std::max(0, (int)generated_tokens.size() - penalty_window);
        std::unordered_set<int> recent_tokens(generated_tokens.begin() + start_idx, generated_tokens.end());

        for (int token : recent_tokens) {
            if (token < 0 || token >= (int)logits.size()) continue;

            if (logits[token] > 0) {
                logits[token] /= std::sqrt(repetition_penalty);
            } else {
                logits[token] *= std::sqrt(repetition_penalty);
            }
        }
    }

    // Diversity penalty: down-weights tokens that belong to known common phrases
    void apply_diversity_penalty(std::vector<float> &logits, const std::vector<int> &common_phrases, float penalty)
    {
        for (int token : common_phrases) {
            if (token >= 0 && token < (int)logits.size()) {
                logits[token] *= penalty;
            }
        }
    }

    // Numerically stable softmax (subtracts the max logit before exponentiating)
    std::vector<float> softmax(const std::vector<float> &logits)
    {
        std::vector<float> probs(logits.size());
        float max_logit = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;

        for (size_t i = 0; i < logits.size(); ++i) {
            probs[i] = std::exp(logits[i] - max_logit);
            sum += probs[i];
        }

        for (float &p : probs) {
            p /= sum;
        }

        return probs;
    }

    // Top-p (nucleus) sampling with a max-heap: dynamically prunes low-probability tokens
    int faster_top_p_sampling(const std::vector<float> &logits, float top_p)
    {
        // Convert logits to probabilities
        std::vector<float> probs = softmax(logits);

        // Build a max-heap of (probability, index) pairs
        std::vector<std::pair<float, size_t>> prob_index;
        prob_index.reserve(logits.size());
        for (size_t i = 0; i < logits.size(); ++i) {
            prob_index.emplace_back(probs[i], i);
        }
        auto cmp = [](const auto &a, const auto &b) { return a.first < b.first; };
        std::make_heap(prob_index.begin(), prob_index.end(), cmp);

        // Pop the most likely tokens until the cumulative probability reaches top_p
        std::vector<size_t> filtered_indices;
        std::vector<float> filtered_probs;
        float cumulative_prob = 0.0f;

        while (!prob_index.empty() && cumulative_prob < top_p) {
            std::pop_heap(prob_index.begin(), prob_index.end(), cmp);
            auto [prob, index] = prob_index.back();
            prob_index.pop_back();

            cumulative_prob += prob;
            filtered_indices.push_back(index);
            filtered_probs.push_back(prob);

            if (cumulative_prob >= top_p) break;
        }

        // Edge case: if nothing was selected (e.g. all probabilities are zero), fall back to token 0
        if (filtered_indices.empty()) return 0;

        // Thread-local RNG so concurrent callers stay thread-safe
        static thread_local std::mt19937 gen(std::random_device{}());
        std::discrete_distribution<int> dist(filtered_probs.begin(), filtered_probs.end());
        return filtered_indices[dist(gen)];
    }

    // Exact top-p sampler that sorts the full distribution (name and signature assumed; elided in this excerpt)
    int top_p_sampling(const std::vector<float> &logits, float top_p)
    {
        std::vector<float> probs = softmax(logits);
        // Sort indices by probability in descending order
        std::vector<size_t> indices(logits.size());
        std::iota(indices.begin(), indices.end(), 0);
        std::sort(indices.begin(), indices.end(), [&](size_t i, size_t j) { return probs[i] > probs[j]; });

        // Compute cumulative probabilities
        float cumulative_prob = 0.0f;
        size_t cut_off = 0;
        for (; cut_off < indices.size(); ++cut_off) {
            cumulative_prob += probs[indices[cut_off]];
            if (cumulative_prob >= top_p) break;
        }

        // Keep only the top-p probabilities (clamped so the range never runs past the end)
        size_t keep = std::min(cut_off + 1, indices.size());
        std::vector<size_t> filtered_indices(indices.begin(), indices.begin() + keep);
        std::vector<float> filtered_probs(filtered_indices.size());
        for (size_t i = 0; i < filtered_indices.size(); ++i) {
            filtered_probs[i] = probs[filtered_indices[i]];
        }

        // Normalize the probabilities
        float filtered_sum = std::accumulate(filtered_probs.begin(), filtered_probs.end(), 0.0f);
        for (float &p : filtered_probs) {
            p /= filtered_sum;
        }

        // Sample from the renormalized top-p distribution
        std::random_device rd;
        std::mt19937 gen(rd());
        std::discrete_distribution<int> dist(filtered_probs.begin(), filtered_probs.end());
        return filtered_indices[dist(gen)];
    }

    // Top-k sampling: restricts the candidate set to the k highest-scoring tokens
    int top_k_sampling(const std::vector<float> &logits, int k)
    {
        // std::vector<float> probs = softmax(logits);

        // Select the indices of the k largest logits
        std::vector<size_t> indices(logits.size());
        std::iota(indices.begin(), indices.end(), 0);
        std::partial_sort(indices.begin(), indices.begin() + k, indices.end(),
                          [&](size_t i, size_t j) { return logits[i] > logits[j]; });

        // Keep only the top-k logits and turn them into probabilities
        std::vector<size_t> filtered_indices(indices.begin(), indices.begin() + k);
        std::vector<float> filtered_probs(k);
        for (int i = 0; i < k; ++i) {
            filtered_probs[i] = logits[filtered_indices[i]];
        }
        filtered_probs = softmax(filtered_probs);

        // Normalize (defensive; the softmax output already sums to 1)
        float sum = std::accumulate(filtered_probs.begin(), filtered_probs.end(), 0.0f);
        for (float &p : filtered_probs) {
            p /= sum;
        }

        // Sample from the top-k distribution
        std::random_device rd;
        std::mt19937 gen(rd());
        std::discrete_distribution<int> dist(filtered_probs.begin(), filtered_probs.end());
        return filtered_indices[dist(gen)];
    }

    bool enable_temperature = false;
    float temperature = 1.0f;

    bool enable_repetition_penalty = false;
    float repetition_penalty = 1.0f;
    int penalty_window = 20;

    bool enable_diversity_penalty = false;
    std::vector<int> common_phrases;
    float diversity_penalty = 1.0f;

    bool enable_top_p_sampling = false;
    float top_p = 1.0f;

    bool enable_top_k_sampling = false;
    int top_k = 1;

public:
    LLMPostprocess() {}

    void set_temperature(bool enable, float temperature)
    {
        enable_temperature = enable;
        this->temperature = temperature;
    }

    void set_repetition_penalty(bool enable, float penalty, int penalty_window)
    {
        enable_repetition_penalty = enable;
        this->repetition_penalty = penalty;
        this->penalty_window = penalty_window;
    }

    void set_diversity_penalty(bool enable, const std::vector<int> &common_phrases, float penalty)
    {
        enable_diversity_penalty = enable;
        this->common_phrases = common_phrases;
        this->diversity_penalty = penalty;
    }

    // Enabling top-p disables top-k: the two sampling modes are mutually exclusive
    void set_top_p_sampling(bool enable, float top_p)
    {
        enable_top_k_sampling = false;
        enable_top_p_sampling = enable;
        this->top_p = top_p;
    }

    // Enabling top-k disables top-p
    void set_top_k_sampling(bool enable, int top_k)
    {
        enable_top_p_sampling = false;
        enable_top_k_sampling = enable;
        this->top_k = top_k;
    }

    bool load_config(std::string config_path)
    {
        std::ifstream config_file(config_path);
        if (!config_file.is_open()) {
            ALOGE("config file(%s) open failed", config_path.c_str());
            return false;
        }
        nlohmann::json config = nlohmann::json::parse(config_file);
        ALOGI("load config: \n%s\n", config.dump(4).c_str());

        enable_temperature = config["enable_temperature"];
        temperature = config["temperature"];

        enable_repetition_penalty = config["enable_repetition_penalty"];
        repetition_penalty = config["repetition_penalty"];
        penalty_window = config["penalty_window"];

        enable_top_p_sampling = config["enable_top_p_sampling"];
        top_p = config["top_p"];

        enable_top_k_sampling = config["enable_top_k_sampling"];
        top_k = config["top_k"];
        return true;
    }

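    // Example config file (illustrative values; keys match the reads above):
    // {
    //     "enable_temperature": true,        "temperature": 0.7,
    //     "enable_repetition_penalty": true, "repetition_penalty": 1.1, "penalty_window": 20,
    //     "enable_top_p_sampling": true,     "top_p": 0.9,
    //     "enable_top_k_sampling": false,    "top_k": 10
    // }
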
    bool load_config(const nlohmann::json &config)
    {
        if (config.is_null()) {
            ALOGE("config is null or invalid");
            return false;
        }
        // ... (rest of the JSON-based config parsing is elided in this excerpt)

    // Runs the enabled post-processing steps on the logits and samples the next token id
    int apply(std::vector<float> &logits, const std::vector<int> &history)
    {
        if (enable_temperature) apply_temperature(logits, temperature);
        if (enable_repetition_penalty) apply_repetition_penalty(logits, history, repetition_penalty, penalty_window);
        if (enable_diversity_penalty) apply_diversity_penalty(logits, common_phrases, diversity_penalty);

        if (enable_top_p_sampling)
            return faster_top_p_sampling(logits, top_p);
        else if (enable_top_k_sampling)
            return top_k_sampling(logits, top_k);
        else {
            // Greedy decoding: return the index of the largest logit
            auto max_it = std::max_element(logits.begin(), logits.end());
            return (int)std::distance(logits.begin(), max_it);
        }
    }
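
    // Example usage (sketch; assumes the caller owns the logits vector and the generated-token history,
    // and that "post_config.json" is an illustrative path):
    //   LLMPostprocess post;
    //   post.load_config("post_config.json");   // or configure via the setters below
    //   post.set_temperature(true, 0.8f);
    //   post.set_top_p_sampling(true, 0.9f);
    //   int next_token = post.apply(logits, generated_tokens);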