Commit 4e3d7f3

Author: LittleMouse (committed)
[update] fix llm generate bug
1 parent: f9de469

File tree

2 files changed (+14, -32 lines)

projects/llm_framework/main_cosy_voice/src/main.cpp

Lines changed: 10 additions & 30 deletions
@@ -411,45 +411,42 @@ class llm_task {
         timer time_total;
         time_total.start();
         try {
-            auto llm_thread_func = [this, &text, &prompt_text_embeds, &prompt_speech_embeds]() {
-                lLaMa_->Run(text, prompt_text_embeds, prompt_speech_embeds, g_token_buffer, g_buffer_mutex, g_buffer_cv,
-                            g_llm_finished);
+            int llm_ret = 0;
+            auto llm_thread_func = [this, &text, &prompt_text_embeds, &prompt_speech_embeds, &llm_ret]() {
+                llm_ret = lLaMa_->Run(text, prompt_text_embeds, prompt_speech_embeds, g_token_buffer, g_buffer_mutex,
+                                      g_buffer_cv, g_llm_finished);
             };
-
             std::thread llm_thread(llm_thread_func);
-
-            int token_offset = 0;
+            llm_thread.detach();
             int prompt_token_len = prompt_speech_embeds_flow.size() / lToken2Wav._attr.flow_embed_size;
             if (prompt_token_len < 75) {
                 SLOGE("Error, prompt speech token len %d < 75", prompt_token_len);
                 if (llm_thread.joinable()) llm_thread.join();
                 return -1;
             }
+            if (llm_ret == -1) {
+                return llm_ret;
+            }
             int prompt_token_align_len = 75;
-
             std::vector<float> prompt_speech_embeds_flow1;
             prompt_speech_embeds_flow1.insert(prompt_speech_embeds_flow1.begin(), prompt_speech_embeds_flow.begin(),
                                               prompt_speech_embeds_flow.begin() + prompt_token_align_len * 512);
-
             std::vector<float> prompt_feat1;
             prompt_feat1.insert(prompt_feat1.begin(), prompt_feat.begin(),
                                 prompt_feat.begin() + prompt_token_align_len * 2 * 80);
-
             int promot_token_pad = 0;
             int this_token_hop_len;
-            int i = 0;
+            int token_offset = 0;
+            int i = 0;
             while (true) {
                 this_token_hop_len = (token_offset == 0) ? lToken2Wav._attr.token_hop_len + promot_token_pad
                                                          : lToken2Wav._attr.token_hop_len;
-
                 std::unique_lock<std::mutex> lock(g_buffer_mutex);
-
                 g_buffer_cv.wait(lock, [&] {
                     return (g_token_buffer.size() - token_offset >=
                             this_token_hop_len + lToken2Wav._attr.pre_lookahead_len) ||
                            g_llm_finished.load() || g_stop.load();
                 });
-
                 if (g_stop) {
                     lock.unlock();
                     break;
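
The core of the fix is in this hunk: the generator's return code is now captured through a reference in the lambda, and the thread is detached so the consumer loop paces itself on the condition variable instead of a final join. Below is a minimal, self-contained sketch of that producer/consumer pattern; every name here is illustrative, not taken from the repo.

#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

int main() {
    int worker_ret = 0;              // written by the producer, read after it finishes
    std::mutex m;
    std::condition_variable cv;
    std::vector<int> tokens;         // shared token buffer, guarded by m
    std::atomic<bool> finished{false};

    auto producer = [&]() {          // stand-in for lLaMa_->Run(...)
        for (int t = 0; t < 8; ++t) {
            std::lock_guard<std::mutex> lk(m);
            tokens.push_back(t);
            cv.notify_all();
        }
        worker_ret = 0;              // -1 would signal a generation failure
        finished.store(true);        // publish only after the return code is set
        cv.notify_all();
    };

    std::thread th(producer);
    th.detach();                     // as in the patch; no join at the end

    size_t offset = 0;               // mirrors token_offset in the real loop
    while (true) {
        std::unique_lock<std::mutex> lk(m);
        cv.wait(lk, [&] { return tokens.size() > offset || finished.load(); });
        while (offset < tokens.size()) std::printf("token %d\n", tokens[offset++]);
        if (finished.load() && offset == tokens.size()) break;
    }
    return worker_ret;               // safe: finished was observed true first
}

Writing worker_ret before the atomic finished flag is what makes the later read race-free; the same ordering concern applies to any status variable shared with a detached thread.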
@@ -460,9 +457,7 @@ class llm_task {
                                 lToken2Wav._attr.max_infer_chunk_num - 1) *
                                    lToken2Wav._attr.token_hop_len;
                 int end = token_offset + this_token_hop_len + lToken2Wav._attr.pre_lookahead_len;
-
                 token.insert(token.end(), g_token_buffer.begin() + start, g_token_buffer.begin() + end);
-
                 lock.unlock();
                 auto speech = lToken2Wav.infer(token, prompt_speech_embeds_flow1, prompt_feat1, spk_embeds,
                                                token_offset, false);
@@ -481,7 +476,6 @@ class llm_task {
                     if (val < -1.0f) val = -1.0f;
                     wav_pcm_data.push_back(static_cast<int16_t>(val * 32767.0f));
                 }
-
                 if (out_callback_) {
                     out_callback_(std::string(reinterpret_cast<char *>(wav_pcm_data.data()),
                                               wav_pcm_data.size() * sizeof(int16_t)),
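
The clamp-and-scale seen in the context lines around this hunk is the standard float-to-PCM16 conversion. As a standalone sketch (the helper name is made up):

#include <cstdint>
#include <vector>

// Convert normalized float samples to 16-bit PCM: clamp to [-1, 1] first so
// the scaled value cannot overflow int16_t, then scale by 32767.
std::vector<int16_t> to_pcm16(const std::vector<float> &samples) {
    std::vector<int16_t> out;
    out.reserve(samples.size());
    for (float val : samples) {
        if (val > 1.0f) val = 1.0f;
        if (val < -1.0f) val = -1.0f;
        out.push_back(static_cast<int16_t>(val * 32767.0f));
    }
    return out;
}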
@@ -496,10 +490,6 @@ class llm_task {
                 }
             }

-            if (llm_thread.joinable()) {
-                llm_thread.join();
-            }
-
             if (g_stop) {
                 g_token_buffer.erase(g_token_buffer.begin(), g_token_buffer.end());
                 return 1;
@@ -518,7 +508,6 @@ class llm_task {
             std::vector<float> resampled_pcm(static_cast<size_t>(speech.size() * src_ratio + 1));
             int resampled_len = 0;
             resample_audio(speech.data(), speech.size(), resampled_pcm.data(), &resampled_len, src_ratio);
-
             std::vector<int16_t> wav_pcm_data;
             wav_pcm_data.reserve(resampled_len);
             for (int i = 0; i < resampled_len; i++) {
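
resample_audio is a repo-internal helper, but the buffer sizing in the context above (input length times src_ratio, plus one) is worth illustrating. A naive linear-interpolation resampler with the same output sizing, purely as a sketch and not the library's algorithm:

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<float> resample_linear(const std::vector<float> &in, double src_ratio) {
    if (in.empty() || src_ratio <= 0.0) return {};
    // Same sizing rule as the caller: n * ratio + 1 covers the last partial sample.
    size_t out_len = static_cast<size_t>(in.size() * src_ratio + 1);
    std::vector<float> out(out_len);
    for (size_t i = 0; i < out_len; ++i) {
        double pos   = static_cast<double>(i) / src_ratio;       // position in input
        size_t i0    = std::min(static_cast<size_t>(pos), in.size() - 1);
        size_t i1    = std::min(i0 + 1, in.size() - 1);
        double frac  = pos - static_cast<double>(i0);
        if (frac > 1.0) frac = 1.0;                              // clamp at the tail
        out[i] = static_cast<float>((1.0 - frac) * in[i0] + frac * in[i1]);
    }
    return out;
}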
@@ -538,7 +527,6 @@ class llm_task {
             std::vector<float> resampled_pcm(static_cast<size_t>(output.size() * src_ratio + 1));
             int resampled_len = 0;
             resample_audio(output.data(), output.size(), resampled_pcm.data(), &resampled_len, src_ratio);
-
             std::vector<int16_t> wav_pcm_data_full;
             wav_pcm_data_full.reserve(resampled_len);
             for (int i = 0; i < resampled_len; i++) {
@@ -547,7 +535,6 @@ class llm_task {
                 if (val < -1.0f) val = -1.0f;
                 wav_pcm_data_full.push_back(static_cast<int16_t>(val * 32767.0f));
             }
-
             std::string wav_path;
             if (mode_config_.output_path.empty()) {
                 wav_path = generateFilename("/tmp");
@@ -561,14 +548,12 @@ class llm_task {
                 }
                 saveVectorAsWavFloat(resampled_pcm, wav_path, mode_config_.audio_rate, 1);
             }
-
             SLOGI("tts total use time: %.3f s", time_total.cost() / 1000);
             reset();
         } catch (const std::exception &e) {
             std::cerr << "Error in pipeline: " << e.what() << std::endl;
             return 1;
         }
-
         return 0;
     }

@@ -599,12 +584,7 @@ class llm_task {
     void inference(const std::string &msg)
     {
         try {
-            // std::string out = lLaMa_->Run(prompt_complete(msg));
-            // if (out_callback_) out_callback_(out, true);
             tts(msg, prompt_text_embeds, prompt_speech_embeds, prompt_feat, prompt_speech_embeds_flow, spk_embeds);
-            std::string out = "finish";
-            if (out_callback_) out_callback_(out, true);
-
         } catch (...) {
             SLOGW("lLaMa_->Run have error!");
         }
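
The removed lines show inference() used to emit a synthetic "finish" message through out_callback_ after tts() returned. A sketch of the (payload, finished) callback contract this implies, under the assumption that the streaming path now flags its final chunk itself; all names and values here are illustrative:

#include <cstdio>
#include <functional>
#include <string>

// Hypothetical callback type matching out_callback_'s (payload, finished) shape.
using out_cb = std::function<void(const std::string &, bool)>;

// Stand-in for tts(): streams PCM chunks and marks the last one finished,
// so no separate "finish" message is needed afterwards.
void tts_stub(const out_cb &cb) {
    cb(std::string(640, '\0'), false);  // intermediate chunk
    cb(std::string(640, '\0'), true);   // final chunk carries finished = true
}

int main() {
    tts_stub([](const std::string &chunk, bool finished) {
        std::printf("chunk: %zu bytes, finished=%d\n", chunk.size(), int(finished));
    });
    return 0;
}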

projects/llm_framework/main_cosy_voice/src/runner/LLM.hpp

Lines changed: 4 additions & 2 deletions
@@ -342,7 +342,7 @@ class LLM {
     {
         std::vector<unsigned short> text_embed;
        std::vector<std::vector<int>> position_ids;
-        Encode(text_embed, position_ids, input_str, prompt_text_embeds, prompt_speech_embeds);
+        if (Encode(text_embed, position_ids, input_str, prompt_text_embeds, prompt_speech_embeds)) return -1;
        return Run(text_embed, position_ids, token_buffer, buffer_mutex, buffer_cv, llm_finished);
     }

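This one-line change makes an Encode failure abort the run instead of being silently ignored, which is where the -1 checked back in main.cpp comes from. The idiom, reduced to a standalone sketch with illustrative names:

#include <string>
#include <vector>

// A tokenizer stand-in that reports failure with a nonzero return code.
static int Encode(std::vector<unsigned short> &embed, const std::string &input) {
    if (input.empty()) return -1;   // nothing to encode
    embed.assign(input.begin(), input.end());
    return 0;
}

int Run(const std::string &input_str) {
    std::vector<unsigned short> text_embed;
    // Propagate the encode failure instead of decoding garbage.
    if (Encode(text_embed, input_str)) return -1;
    // ... the autoregressive decode would consume text_embed here ...
    return 0;
}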
@@ -560,7 +560,9 @@ class LLM {
             if (b_stop) {
                 break;
             }
-
+            if (indices >= _attr.kv_cache_num) {
+                break;
+            }
             speech_embed_selector.getByIndex(next_token, embed.data());
             memcpy((void *)llama_layers[0].layer.get_input(decode_grpid, "input").pVirAddr, embed.data(),
                    llama_layers[0].layer.get_input(decode_grpid, "input").nSize);
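
The added guard stops decoding once the write index reaches the KV-cache capacity; without it, the next step would write past the preallocated cache slots. A toy version of the guard (capacity and index values are made up):

#include <cstdio>

int main() {
    const int kv_cache_num = 1024;      // cache capacity fixed at model load time
    int indices = 1020;                 // current decode step / cache write slot

    while (true) {
        if (indices >= kv_cache_num) {  // same guard as the patch
            std::printf("kv cache full at step %d, stopping decode\n", indices);
            break;
        }
        // ... run one decode step, writing K/V entries at slot `indices` ...
        ++indices;
    }
    return 0;
}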
