@@ -20,69 +20,67 @@ typedef std::function<void(int *, int, const char *, float, void *)> LLMRuningCa
2020
2121struct LLMAttrType {
2222 std::string system_prompt;
23+
2324 std::string template_filename_axmodel = " tinyllama-int8/tinyllama_l%d.axmodel" ;
25+ std::string post_config_path = " post_config.json" ;
2426 int axmodel_num = 22 ;
2527
26- std::string filename_post_axmodel = " tinyllama-int8/tinyllama_post.axmodel" ;
27- std::string filename_image_encoder_axmodel = " minicpmv/vpm_resampler_version0_fp16.axmodel" ;
28- std::string filename_vpm_encoder_axmodel = " minicpmv/vpm_resampler_version0_fp16.axmodel" ;
28+ std::string filename_image_encoder_axmodel = " minicpmv/vpm_resampler_version0_fp16.axmodel" ;
29+ std::string filename_vpm_encoder_axmodel = " minicpmv/vpm_resampler_version0_fp16.axmodel" ;
2930 std::string filename_vpm_resampler_axmodedl = " minicpmv/vpm_resampler_version0_fp16.axmodel" ;
3031
31- int image_encoder_width = 448 ;
32- int image_encoder_height = 448 ;
33- int vpm_width = 280 ;
34- int vpm_height = 280 ;
35- bool b_vpm_two_stage = false ;
32+ int image_encoder_width = 448 ;
33+ int image_encoder_height = 448 ;
34+ int vpm_width = 280 ;
35+ int vpm_height = 280 ;
36+ bool b_vpm_two_stage = false ;
37+ int IMAGE_CONTEXT_TOKEN = 151667 ;
38+ int IMAGE_START_TOKEN = 151665 ;
39+ int IMAGE_ENCODER_INPUT_NCHW = -1 ;
40+ int IMAGE_ENCODER_OUTPUT_BF16 = -1 ;
3641
3742 int prefill_token_num = 96 ;
3843 int prefill_max_token_num = 512 ;
39- std::vector<int > prefill_max_kv_cache_num_grp;
40- int precompute_len = 0 ;
41- int prefill_grpid = -1 ;
44+
45+ std::string filename_post_axmodel = " tinyllama-int8/tinyllama_post.axmodel" ;
4246
4347 TokenizerType tokenizer_type = TKT_LLaMa;
4448 std::string filename_tokenizer_model = " tokenizer.model" ;
4549 std::string url_tokenizer_model;
46- bool b_bos = true , b_eos = false ;
50+ bool b_bos = true ;
51+ bool b_eos = false ;
4752 std::string filename_tokens_embed = " tinyllama.model.embed_tokens.weight.bfloat16.bin" ;
4853 int tokens_embed_num = 32000 ;
4954 int img_token_id = 151667 ;
5055 int tokens_embed_size = 2048 ;
5156
5257 int max_token_len = 127 ;
53-
5458 int kv_cache_num = 1024 ;
5559 int kv_cache_size = 256 ;
5660
61+ int precompute_len = 0 ;
62+ std::vector<int > prefill_max_kv_cache_num_grp;
63+ int prefill_grpid = -1 ;
64+
5765 bool enable_temperature = false ;
5866 float temperature = 0 .7f ;
5967
6068 bool enable_top_p_sampling = false ;
6169 float top_p = 0 .7f ;
6270
63- bool enable_top_k_sampling = false ;
64- int top_k = 50 ;
71+ bool enable_top_k_sampling = true ;
72+ int top_k = 10 ;
6573
6674 bool enable_repetition_penalty = false ;
6775 float repetition_penalty = 1 .2f ;
6876 int penalty_window = 50 ;
6977
7078 bool b_use_mmap_load_embed = false ;
7179 bool b_dynamic_load_axmodel_layer = false ;
80+ bool b_use_mmap_load_layer = true ;
7281
73- bool b_use_mmap_load_layer = true ;
74-
75- bool b_use_topk = false ;
76- std::string post_config_path = " post_config.json" ;
77-
78- // bool b_live_print = true;
7982 LLMRuningCallback runing_callback = nullptr ;
8083 void *reserve = nullptr ;
81-
82- int IMAGE_CONTEXT_TOKEN = 151667 ;
83- int IMAGE_START_TOKEN = 151665 ;
84- int IMAGE_ENCODER_INPUT_NCHW = -1 ;
85- int IMAGE_ENCODER_OUTPUT_BF16 = -1 ;
8684};
8785
8886class LLM {
@@ -142,7 +140,6 @@ class LLM {
142140 return false ;
143141 }
144142 update_cqdm (&cqdm, 1 , " count" , " embed_selector init ok" );
145-
146143 llama_layers.resize (attr.axmodel_num );
147144
148145 char axmodel_path[1024 ];
@@ -241,13 +238,34 @@ class LLM {
241238
242239 _attr.prefill_token_num = llama_layers[0 ].layer .get_input (prefill_grpid, " indices" ).vShape [1 ];
243240 ALOGI (" prefill_token_num : %d" , _attr.prefill_token_num );
244-
245241 ALOGI (" vpm_height : %d,vpm_width : %d" , _attr.vpm_height , _attr.vpm_width );
246242 }
247243 if (attr.b_dynamic_load_axmodel_layer ) {
248244 auto &layer = llama_layers[0 ];
249245 layer.layer .deinit ();
250246 }
247+ nlohmann::json dynamic_config;
248+
249+ dynamic_config[" enable_temperature" ] = _attr.enable_temperature ;
250+ dynamic_config[" temperature" ] = _attr.temperature ;
251+
252+ dynamic_config[" enable_repetition_penalty" ] = _attr.enable_repetition_penalty ;
253+ dynamic_config[" repetition_penalty" ] = _attr.repetition_penalty ;
254+ dynamic_config[" penalty_window" ] = _attr.penalty_window ;
255+
256+ dynamic_config[" enable_top_p_sampling" ] = _attr.enable_top_p_sampling ;
257+ dynamic_config[" top_p" ] = _attr.top_p ;
258+
259+ dynamic_config[" enable_top_k_sampling" ] = _attr.enable_top_k_sampling ;
260+ dynamic_config[" top_k" ] = _attr.top_k ;
261+
262+ if (!postprocess.load_config (attr.post_config_path )) {
263+ ALOGW (" load postprocess config(%s) failed" , attr.post_config_path .c_str ());
264+ }
265+
266+ if (!postprocess.load_config (dynamic_config)) {
267+ ALOGW (" load postprocess config(%s) failed" , dynamic_config.dump (4 ).c_str ());
268+ }
251269
252270 // Reset();
253271 ALOGI (" LLM init ok" );
@@ -483,19 +501,15 @@ class LLM {
483501 auto &input = llama_post.get_input (" input" );
484502 memcpy (input.pVirAddr , embed.data (), embed.size () * sizeof (unsigned short ));
485503 llama_post.inference ();
504+
486505 int max_index;
487- if (_attr.b_use_topk ) {
488- AX_SYS_MinvalidateCache (llama_post.get_output (" indices" ).phyAddr ,
489- llama_post.get_output (" indices" ).pVirAddr ,
490- llama_post.get_output (" indices" ).nSize );
491- max_index = *(int *)llama_post.get_output (" indices" ).pVirAddr ;
492- } else {
493- auto &output_post = llama_post.get_output (" output" );
494- AX_SYS_MinvalidateCache (output_post.phyAddr , output_post.pVirAddr , output_post.nSize );
495- unsigned short *post_out = (unsigned short *)output_post.pVirAddr ;
496- float max_val = -MAXFLOAT;
497- max_index = post_process (postprocess, post_out, _attr.tokens_embed_num , token_ids, &max_val);
498- }
506+
507+ auto &output_post = llama_post.get_output (" output" );
508+ AX_SYS_MinvalidateCache (output_post.phyAddr , output_post.pVirAddr , output_post.nSize );
509+ unsigned short *post_out = (unsigned short *)output_post.pVirAddr ;
510+ float max_val = -MAXFLOAT;
511+ max_index = post_process (postprocess, post_out, _attr.tokens_embed_num , token_ids, &max_val);
512+
499513 next_token = max_index;
500514
501515 token_ids.push_back (max_index);
@@ -574,18 +588,13 @@ class LLM {
574588 memcpy (input.pVirAddr , embed.data (), embed.size () * sizeof (unsigned short ));
575589 llama_post.inference ();
576590 int max_index;
577- if (_attr.b_use_topk ) {
578- AX_SYS_MinvalidateCache (llama_post.get_output (" indices" ).phyAddr ,
579- llama_post.get_output (" indices" ).pVirAddr ,
580- llama_post.get_output (" indices" ).nSize );
581- max_index = *(int *)llama_post.get_output (" indices" ).pVirAddr ;
582- } else {
583- auto &output_post = llama_post.get_output (" output" );
584- AX_SYS_MinvalidateCache (output_post.phyAddr , output_post.pVirAddr , output_post.nSize );
585- unsigned short *post_out = (unsigned short *)output_post.pVirAddr ;
586- float max_val = -MAXFLOAT;
587- max_index = post_process (postprocess, post_out, _attr.tokens_embed_num , token_ids, &max_val);
588- }
591+
592+ auto &output_post = llama_post.get_output (" output" );
593+ AX_SYS_MinvalidateCache (output_post.phyAddr , output_post.pVirAddr , output_post.nSize );
594+ unsigned short *post_out = (unsigned short *)output_post.pVirAddr ;
595+ float max_val = -MAXFLOAT;
596+ max_index = post_process (postprocess, post_out, _attr.tokens_embed_num , token_ids, &max_val);
597+
589598 next_token = max_index;
590599
591600 if (tokenizer->isEnd (max_index)) {
0 commit comments