@@ -460,14 +460,15 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
         self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
         self.end_token = None
         self.min_padding = min_padding
+        self.pad_left = pad_left

         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -522,6 +523,12 @@ def _try_get_embedding(self, embedding_name:str):
                 return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
         return (embed, leftover)

+    def pad_tokens(self, tokens, amount):
+        if self.pad_left:
+            for i in range(amount):
+                tokens.insert(0, (self.pad_token, 1.0, 0))
+        else:
+            tokens.extend([(self.pad_token, 1.0, 0)] * amount)

     def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
@@ -600,7 +607,7 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_optio
                         if self.end_token is not None:
                             batch.append((self.end_token, 1.0, 0))
                         if self.pad_to_max_length:
-                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
+                            self.pad_tokens(batch, remaining_length)
                     #start new batch
                     batch = []
                     if self.start_token is not None:
@@ -614,11 +621,11 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_optio
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
         if min_padding is not None:
-            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+            self.pad_tokens(batch, min_padding)
         if self.pad_to_max_length and len(batch) < self.max_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            self.pad_tokens(batch, self.max_length - len(batch))
         if min_length is not None and len(batch) < min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
+            self.pad_tokens(batch, min_length - len(batch))

         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
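
For reference, a minimal standalone sketch of the padding behavior this diff introduces. The MiniPadder class and the pad token value 0 are placeholders for illustration only; they are not part of SDTokenizer or this PR.

# Illustration only: mirrors the pad_tokens logic added above.
# MiniPadder and pad_token=0 are assumptions made for this example.
class MiniPadder:
    def __init__(self, pad_token=0, pad_left=False):
        self.pad_token = pad_token
        self.pad_left = pad_left

    def pad_tokens(self, tokens, amount):
        # Prepend pad entries when pad_left is set, otherwise append them.
        if self.pad_left:
            for _ in range(amount):
                tokens.insert(0, (self.pad_token, 1.0, 0))
        else:
            tokens.extend([(self.pad_token, 1.0, 0)] * amount)

tokens = [(320, 1.0, 1), (1125, 1.0, 2)]  # arbitrary (token_id, weight, word_id) triples

right = list(tokens)
MiniPadder(pad_left=False).pad_tokens(right, 2)
# right == [(320, 1.0, 1), (1125, 1.0, 2), (0, 1.0, 0), (0, 1.0, 0)]

left = list(tokens)
MiniPadder(pad_left=True).pad_tokens(left, 2)
# left == [(0, 1.0, 0), (0, 1.0, 0), (320, 1.0, 1), (1125, 1.0, 2)]

Either way the pad entries carry weight 1.0 and word id 0, so only their position in the batch changes.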