Commit bd01d9f

Add left padding support to tokenizers. (comfyanonymous#10753)
1 parent: 443056c

1 file changed: +12 -5 lines changed

comfy/sd1_clip.py

Lines changed: 12 additions & 5 deletions
@@ -460,14 +460,15 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
         self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
         self.end_token = None
         self.min_padding = min_padding
+        self.pad_left = pad_left
 
         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -522,6 +523,12 @@ def _try_get_embedding(self, embedding_name:str):
                 return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
         return (embed, leftover)
 
+    def pad_tokens(self, tokens, amount):
+        if self.pad_left:
+            for i in range(amount):
+                tokens.insert(0, (self.pad_token, 1.0, 0))
+        else:
+            tokens.extend([(self.pad_token, 1.0, 0)] * amount)
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
@@ -600,7 +607,7 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_optio
                         if self.end_token is not None:
                             batch.append((self.end_token, 1.0, 0))
                         if self.pad_to_max_length:
-                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
+                            self.pad_tokens(batch, remaining_length)
                         #start new batch
                         batch = []
                         if self.start_token is not None:
@@ -614,11 +621,11 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_optio
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
         if min_padding is not None:
-            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+            self.pad_tokens(batch, min_padding)
         if self.pad_to_max_length and len(batch) < self.max_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            self.pad_tokens(batch, self.max_length - len(batch))
         if min_length is not None and len(batch) < min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
+            self.pad_tokens(batch, min_length - len(batch))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
