
Commit 18b75f4

Make the ScaleRope node work on Z Image and Lumina.
1 parent 5151cff commit 18b75f4

File tree

1 file changed (+14 −2 lines)


comfy/ldm/lumina/model.py

Lines changed: 14 additions & 2 deletions
@@ -517,11 +517,23 @@ def patchify_and_embed(
         B, C, H, W = x.shape
         x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
 
+        rope_options = transformer_options.get("rope_options", None)
+        h_scale = 1.0
+        w_scale = 1.0
+        h_start = 0
+        w_start = 0
+        if rope_options is not None:
+            h_scale = rope_options.get("scale_y", 1.0)
+            w_scale = rope_options.get("scale_x", 1.0)
+
+            h_start = rope_options.get("shift_y", 0.0)
+            w_start = rope_options.get("shift_x", 0.0)
+
         H_tokens, W_tokens = H // pH, W // pW
         x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
         x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
-        x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
-        x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+        x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
+        x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
 
         if self.pad_tokens_multiple is not None:
             pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
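For reference, a minimal standalone sketch (not part of the commit) of how the scale_x / scale_y / shift_x / shift_y keys read from transformer_options["rope_options"] — presumably populated by the ScaleRope node — rescale and shift the token position ids. The scaled_pos_ids helper and the example values below are illustrative only.

import torch

# Hypothetical helper mirroring the scale/shift logic the commit adds to
# patchify_and_embed: builds the (row, col) coordinates for an
# H_tokens x W_tokens patch grid.
def scaled_pos_ids(H_tokens, W_tokens, rope_options=None, device="cpu"):
    # Defaults match the diff: identity scale, zero shift.
    h_scale = w_scale = 1.0
    h_start = w_start = 0.0
    if rope_options is not None:
        h_scale = rope_options.get("scale_y", 1.0)
        w_scale = rope_options.get("scale_x", 1.0)
        h_start = rope_options.get("shift_y", 0.0)
        w_start = rope_options.get("shift_x", 0.0)

    pos = torch.zeros((H_tokens * W_tokens, 2), dtype=torch.float32, device=device)
    # Row coordinates: scaled and shifted along the height axis.
    pos[:, 0] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
    # Column coordinates: scaled and shifted along the width axis.
    pos[:, 1] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
    return pos

# Example: a 2x2 token grid with the width axis scaled by 2 and the height axis shifted by 1.
print(scaled_pos_ids(2, 2, {"scale_x": 2.0, "shift_y": 1.0}))
# tensor([[1., 0.],
#         [1., 2.],
#         [2., 0.],
#         [2., 2.]])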
