2020#include " timer.hpp"
2121// #include "opencv2/opencv.hpp"
2222#include " ax_sys_api.h"
23+ #include " MNN/MNNDefine.h"
24+ #include " MNN/MNNForwardType.h"
25+ #include " MNN/Interpreter.hpp"
2326
2427class Token2Wav
2528{
@@ -44,8 +47,15 @@ class Token2Wav
4447 ax_runner_ax650 flow_estimator_250;
4548 ax_runner_ax650 flow_estimator_300;
4649
47- ax_runner_ax650 hift_50_first;
48- ax_runner_ax650 hift_58;
50+ ax_runner_ax650 hift_p2_50_first;
51+ ax_runner_ax650 hift_p2_58;
52+
53+ std::shared_ptr<MNN::Interpreter> hift_p1_50_first = nullptr ;
54+ std::shared_ptr<MNN::Interpreter> hift_p1_58 = nullptr ;
55+
56+ MNN::Session * sess_hift_p1_50_first = nullptr ;
57+ MNN::Session * sess_hift_p1_58 = nullptr ;
58+
4959
5060 std::vector<float > rand_noise;
5161 std::vector<float > t_span;
@@ -161,20 +171,44 @@ class Token2Wav
161171 return false ;
162172 }
163173
164- ret = hift_50_first .init ((model_dir+" /hift_50_first .axmodel" ).c_str (), false );
174+ ret = hift_p2_50_first .init ((model_dir+" /hift_p2_50_first .axmodel" ).c_str (), false );
165175 if (ret != 0 )
166176 {
167- ALOGE (" init axmodel(%s) failed" , (model_dir+" /hift_50_first .axmodel" ).c_str ());
177+ ALOGE (" init axmodel(%s) failed" , (model_dir+" /hift_p2_50_first .axmodel" ).c_str ());
168178 return false ;
169179 }
170180
171- ret = hift_58 .init ((model_dir+" /hift_58 .axmodel" ).c_str (), false );
181+ ret = hift_p2_58 .init ((model_dir+" /hift_p2_58 .axmodel" ).c_str (), false );
172182 if (ret != 0 )
173183 {
174- ALOGE (" init axmodel(%s) failed" , (model_dir+" /hift_58 .axmodel" ).c_str ());
184+ ALOGE (" init axmodel(%s) failed" , (model_dir+" /hift_p2_58 .axmodel" ).c_str ());
175185 return false ;
176186 }
177187
188+ MNN::ScheduleConfig config;
189+ config.numThread = 2 ;
190+ config.type = static_cast <MNNForwardType>(MNN_FORWARD_CPU);
191+ MNN::BackendConfig backendConfig;
192+ backendConfig.precision = (MNN::BackendConfig::PrecisionMode)1 ;
193+ config.backendConfig = &backendConfig;
194+
195+ hift_p1_50_first = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile ( (model_dir+" /hift_p1_50_first.mnn" ).c_str () ));
196+ if (nullptr == hift_p1_50_first)
197+ {
198+ ALOGE (" init mnn model(%s) failed" , (model_dir+" /hift_p1_50_first.mnn" ).c_str ());
199+ return false ;
200+ }
201+ sess_hift_p1_50_first = hift_p1_50_first->createSession (config);
202+
203+ hift_p1_58 = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile ( (model_dir+" /hift_p1_58.mnn" ).c_str () ));
204+ if (nullptr == hift_p1_58)
205+ {
206+ ALOGE (" init mnn model(%s) failed" , (model_dir+" /hift_p1_58.mnn" ).c_str () );
207+ return false ;
208+ }
209+
210+ sess_hift_p1_58 = hift_p1_58->createSession (config);
211+
178212 ALOGI (" Token2Wav init ok" );
179213 return true ;
180214 }
@@ -188,8 +222,8 @@ class Token2Wav
188222 flow_estimator_200.release ();
189223 flow_estimator_250.release ();
190224 flow_estimator_300.release ();
191- hift_50_first .release ();
192- hift_58 .release ();
225+ hift_p2_50_first .release ();
226+ hift_p2_58 .release ();
193227 flow_embed_selector.Deinit ();
194228 }
195229
@@ -318,39 +352,66 @@ class Token2Wav
318352 int infer_hift (std::vector<float > &mel, std::vector<float > &cache_source,
319353 std::vector<float > & tts_speech, std::vector<float > & tts_source)
320354 {
321- ax_runner_ax650 * model;
355+ std::shared_ptr<MNN::Interpreter> model_p1;
356+ MNN::Session * sess_p1;
357+ ax_runner_ax650 * model_p2;
322358 int len = mel.size ()/(80 );
323359
324360 if (len == 50 && cache_source.empty ())
325361 {
326- model = &hift_50_first;
362+ model_p1 = hift_p1_50_first;
363+ sess_p1 = sess_hift_p1_50_first;
364+ model_p2 = &hift_p2_50_first;
327365 }else if (len == 58 && !cache_source.empty ())
328366 {
329- model = &hift_58;
367+ model_p1 = hift_p1_58;
368+ sess_p1 = sess_hift_p1_58;
369+ model_p2 = &hift_p2_58;
330370 }else
331371 {
332372 ALOGE (" invalid size: %d" , len);
333373 return -1 ;
334374 }
335375
336- void * p = model->get_input (" mel" ).pVirAddr ;
376+ std::vector<int > dims{1 , 80 , len};
377+ auto tensor = MNN::Tensor::create<float >(dims, NULL , MNN::Tensor::CAFFE);
378+ auto p_tensor = tensor->host <float >();
379+ auto size = tensor->size ();
380+ std::memcpy (p_tensor, mel.data (), size);
381+
382+ auto inputTensor = model_p1->getSessionInput (sess_p1, nullptr );
383+ inputTensor->copyFromHostTensor (tensor);
384+
385+ model_p1->runSession (sess_p1);
386+
387+ MNN::Tensor *p_out = model_p1->getSessionOutput (sess_p1, " s" );
388+ MNN::Tensor out_host (p_out, p_out->getDimensionType ());
389+ p_out->copyToHostTensor (&out_host);
390+
391+ auto p_s = out_host.host <float >();
392+
393+ void * p = model_p2->get_input (" s" ).pVirAddr ;
394+ memcpy (p, p_s, len * 480 * sizeof (float ));
395+
396+ p = model_p2->get_input (" mel" ).pVirAddr ;
337397 memcpy (p, mel.data (), mel.size () * sizeof (float ));
398+
338399 if (!cache_source.empty ())
339400 {
340- p = model ->get_input (" hift_cache_source" ).pVirAddr ;
401+ p = model_p2 ->get_input (" hift_cache_source" ).pVirAddr ;
341402 memcpy (p, cache_source.data (), cache_source.size () * sizeof (float ));
342403 }
343-
344- model ->inference ();
345-
346- auto &output_speech = model ->get_output (" audio" );
404+
405+ model_p2 ->inference ();
406+
407+ auto &output_speech = model_p2 ->get_output (" audio" );
347408 if (tts_speech.empty () || tts_speech.size () != output_speech.nSize / sizeof (float ))
348409 {
349410 tts_speech.resize (output_speech.nSize / sizeof (float ));
350411 }
351412 memcpy (tts_speech.data (), output_speech.pVirAddr , output_speech.nSize );
352413
353- auto &output_source = model ->get_output (" x " );
414+ auto &output_source = model_p2 ->get_output (1 );
354415 if (tts_source.empty () || tts_source.size () != output_source.nSize / sizeof (float ))
355416 {
356417 tts_source.resize (output_source.nSize / sizeof (float ));
0 commit comments