from typing import Optional
import torch
from transformers import AutoTokenizer
import tiktoken

LANGUAGES = {
    "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian",
    "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
    "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian",
    "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew",
    "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish",
    "hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian",
    "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh",
    "sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian",
    "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
    "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian",
    "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi",
    "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali",
    "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik",
    "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek",
    "fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk",
    "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan",
    "tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian",
    "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
    "yue": "cantonese", "minnan": "minnan", "wuyu": "wuyu", "dialect": "dialect", "zh/en": "zh/en", "en/zh": "en/zh",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb",
    "pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si",
    "castilian": "es", "mandarin": "zh",
}

AUDIO_EVENT = {
    "ASR": "ASR", "AED": "AED", "SER": "SER", "Speech": "Speech", "/Speech": "/Speech",
    "BGM": "BGM", "/BGM": "/BGM", "Laughter": "Laughter", "/Laughter": "/Laughter",
    "Applause": "Applause", "/Applause": "/Applause",
}

EMOTION = {
    "HAPPY": "HAPPY", "SAD": "SAD", "ANGRY": "ANGRY", "NEUTRAL": "NEUTRAL",
}

TTS_Vocal_Token = {
    "TTS/B": "TTS/B", "TTS/O": "TTS/O", "TTS/Q": "TTS/Q", "TTS/A": "TTS/A", "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL", "TTS/H": "TTS/H", **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)},
}


# ===== Build the tiktoken Encoding =====
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    # byte-pair ranks are loaded from the bundled .tiktoken vocabulary; this follows the
    # standard Whisper-style loader and assumes `base64` is imported at the top of the module
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}
    specials = [
        "<|endoftext|>", "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>",
        "<|nospeech|>", "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],  # timestamp tokens, 0.00 s to 30.00 s in 20 ms steps
    ]
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1
    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",  # GPT-2 split pattern, as in the Whisper tokenizer (assumed here)
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )

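# Usage sketch (illustrative, assuming the matching .tiktoken asset exists next to this file):
# the Encoding built above round-trips both plain text and the registered control tokens.
#
#   enc = get_encoding("multilingual_zh_ja_yue_char_del")
#   ids = enc.encode("<|startoftranscript|><|zh|><|transcribe|>", allowed_special="all")
#   assert enc.decode(ids) == "<|startoftranscript|><|zh|><|transcribe|>"
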

class SimpleTokenizer:
    """Thin wrapper around a tiktoken Encoding; stands in for whisper.tokenizer.Tokenizer."""

    def __init__(self, encoding, num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None):
        self.encoding = encoding
        self.num_languages = num_languages
        self.language = language
        self.task = task

    def encode(self, text: str):
        return self.encoding.encode(text)

    def decode(self, tokens: list):
        return self.encoding.decode(tokens)


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> SimpleTokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")
    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None
    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
    return SimpleTokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)

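# Usage sketch (illustrative, assuming the bundled multilingual .tiktoken asset is present):
# get_tokenizer returns a cached SimpleTokenizer whose language/task are kept as attributes,
# and plain text round-trips through the underlying BPE encoding.
#
#   tok = get_tokenizer(multilingual=True, language="mandarin", task="transcribe")
#   assert tok.language == "zh"            # alias resolved via TO_LANGUAGE_CODE
#   ids = tok.encode("hello world")
#   assert tok.decode(ids) == "hello world"
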

class QwenTokenizer():
    def __init__(self, token_path, skip_special_tokens=True):
        super().__init__()
        # NOTE: non-chat model, all these special tokens keep randomly initialized.
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            # ... (remaining special-token entries elided in this excerpt) ...
        }
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        tokens = self.tokenizer([text], return_tensors="pt")
        return tokens["input_ids"][0].cpu().tolist()

    def decode(self, tokens):
        tokens = torch.tensor(tokens, dtype=torch.int64)
        return self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]


@lru_cache(maxsize=None)
def get_qwen_tokenizer(token_path: str, skip_special_tokens: bool) -> QwenTokenizer:
    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
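
# Usage sketch (illustrative): QwenTokenizer delegates to a Hugging Face AutoTokenizer loaded
# from `token_path`, so any local Qwen tokenizer directory works; the path below is a
# placeholder, not an asset of this repo.
#
#   qwen = get_qwen_tokenizer(token_path="/path/to/qwen_tokenizer", skip_special_tokens=True)
#   ids = qwen.encode("hello, world")
#   print(qwen.decode(ids))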