|
86 | 86 | "transcription", |
87 | 87 | "draft", |
88 | 88 | ] |
# Known tokenizer modes. The "custom" value was removed; since the
# ModelConfig.tokenizer_mode field is typed "TokenizerMode | str", other
# plugin-provided mode strings are still accepted at runtime.
TokenizerMode = Literal["auto", "hf", "slow", "mistral"]
90 | 90 | ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] |
91 | 91 | LogprobsMode = Literal[ |
92 | 92 | "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" |
@@ -137,13 +137,13 @@ class ModelConfig: |
137 | 137 | tokenizer: SkipValidation[str] = None # type: ignore |
138 | 138 | """Name or path of the Hugging Face tokenizer to use. If unspecified, model |
139 | 139 | name or path will be used.""" |
140 | | - tokenizer_mode: TokenizerMode = "auto" |
| 140 | + tokenizer_mode: TokenizerMode | str = "auto" |
141 | 141 | """Tokenizer mode:\n |
142 | 142 | - "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n |
143 | 143 | - "hf" will use the fast tokenizer if available.\n |
144 | 144 | - "slow" will always use the slow tokenizer.\n |
145 | 145 | - "mistral" will always use the tokenizer from `mistral_common`.\n |
146 | | - - "custom" will use --tokenizer to select the preregistered tokenizer.""" |
| 146 | + - Other custom values can be supported via plugins.""" |
147 | 147 | trust_remote_code: bool = False |
148 | 148 | """Trust remote code (e.g., from HuggingFace) when downloading the model |
149 | 149 | and tokenizer.""" |
@@ -708,16 +708,17 @@ def _task_to_convert(task: TaskOption) -> ConvertType: |
708 | 708 | # can be correctly capped to sliding window size |
709 | 709 | self.hf_text_config.sliding_window = None |
710 | 710 |
|
711 | | - if not self.skip_tokenizer_init: |
712 | | - self._verify_tokenizer_mode() |
713 | | - |
714 | 711 | # Avoid running try_verify_and_update_config multiple times |
715 | 712 | self.config_updated = False |
716 | 713 |
|
717 | 714 | self._verify_quantization() |
718 | 715 | self._verify_cuda_graph() |
719 | 716 | self._verify_bnb_config() |
720 | 717 |
|
| 718 | + @field_validator("tokenizer_mode", mode="after") |
| 719 | + def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str: |
| 720 | + return tokenizer_mode.lower() |
| 721 | + |
721 | 722 | @field_validator("quantization", mode="before") |
722 | 723 | @classmethod |
723 | 724 | def validate_quantization_before(cls, value: Any) -> Any: |
@@ -829,15 +830,6 @@ def _get_encoder_config(self): |
829 | 830 | model, _ = split_remote_gguf(model) |
830 | 831 | return get_sentence_transformer_tokenizer_config(model, self.revision) |
831 | 832 |
|
832 | | - def _verify_tokenizer_mode(self) -> None: |
833 | | - tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) |
834 | | - if tokenizer_mode not in get_args(TokenizerMode): |
835 | | - raise ValueError( |
836 | | - f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " |
837 | | - f"one of {get_args(TokenizerMode)}." |
838 | | - ) |
839 | | - self.tokenizer_mode = tokenizer_mode |
840 | | - |
841 | 833 | def _get_default_runner_type( |
842 | 834 | self, |
843 | 835 | architectures: list[str], |
|
0 commit comments