Commit b94d394

Support Z Image alibaba pai fun controlnets. (#11062)

These are not actual controlnets, so the weight file goes in the models/model_patches folder and is applied with the ModelPatchLoader + QwenImageDiffsynthControlnet nodes, as sketched below.
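A rough sketch of that wiring as direct Python calls (hypothetical headless use; normally these run as graph nodes in the ComfyUI editor, and the file name plus the model/vae/image inputs below are placeholders, not part of this commit):

    # Hypothetical headless wiring of the two nodes; file name is a placeholder.
    # ModelPatchLoader picks the file up from models/model_patches and, for these
    # weights, builds a ZImage_Control wrapped in a ModelPatcher.
    model_patch = ModelPatchLoader().load_model_patch("z_image_fun_control.safetensors")[0]

    # QwenImageDiffsynthControlnet detects the ZImage_Control instance and
    # attaches a ZImageControlPatch as a double_block patch on the model.
    model_patched = QwenImageDiffsynthControlnet().diffsynth_controlnet(
        model, model_patch, vae, control_image, strength=1.0)[0]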
1 parent 277237c · commit b94d394

3 files changed: +232 -12 lines

comfy/ldm/lumina/controlnet.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+import torch
+from torch import nn
+
+from .model import JointTransformerBlock
+
+class ZImageControlTransformerBlock(JointTransformerBlock):
+    def __init__(
+        self,
+        layer_id: int,
+        dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        multiple_of: int,
+        ffn_dim_multiplier: float,
+        norm_eps: float,
+        qk_norm: bool,
+        modulation=True,
+        block_id=0,
+        operation_settings=None,
+    ):
+        super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
+        self.block_id = block_id
+        if block_id == 0:
+            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, c, x, **kwargs):
+        if self.block_id == 0:
+            c = self.before_proj(c) + x
+        c = super().forward(c, **kwargs)
+        c_skip = self.after_proj(c)
+        return c_skip, c
+
+class ZImage_Control(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int = 3840,
+        n_heads: int = 30,
+        n_kv_heads: int = 30,
+        multiple_of: int = 256,
+        ffn_dim_multiplier: float = (8.0 / 3.0),
+        norm_eps: float = 1e-5,
+        qk_norm: bool = True,
+        dtype=None,
+        device=None,
+        operations=None,
+        **kwargs
+    ):
+        super().__init__()
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+        self.additional_in_dim = 0
+        self.control_in_dim = 16
+        n_refiner_layers = 2
+        self.n_control_layers = 6
+        self.control_layers = nn.ModuleList(
+            [
+                ZImageControlTransformerBlock(
+                    i,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    block_id=i,
+                    operation_settings=operation_settings,
+                )
+                for i in range(self.n_control_layers)
+            ]
+        )
+
+        all_x_embedder = {}
+        patch_size = 2
+        f_patch_size = 1
+        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
+        self.control_noise_refiner = nn.ModuleList(
+            [
+                JointTransformerBlock(
+                    layer_id,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    modulation=True,
+                    z_image_modulation=True,
+                    operation_settings=operation_settings,
+                )
+                for layer_id in range(n_refiner_layers)
+            ]
+        )
+
+    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
+        patch_size = 2
+        f_patch_size = 1
+        pH = pW = patch_size
+        B, C, H, W = control_context.shape
+        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+        x_attn_mask = None
+        for layer in self.control_noise_refiner:
+            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+        return control_context
+
+    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
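The embedder call in ZImage_Control.forward packs each 2x2 latent patch into one token via the view/permute/flatten chain. A minimal standalone sketch of that reshape (toy shapes, chosen only to make the output easy to read):

    import torch

    B, C, H, W = 1, 16, 4, 4          # toy latent: 16 channels on a 4x4 grid
    pH = pW = 2                       # patch_size = 2, as in the model

    x = torch.randn(B, C, H, W)
    # Split H and W into (grid, patch) pairs, move the patch dims next to the
    # channel dim, then flatten each patch into one token of length pH*pW*C.
    tokens = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
    print(tokens.shape)               # torch.Size([1, 4, 64]) == (B, patches, pH*pW*C)
    # These 64-dim patch tokens are what control_all_x_embedder projects to dim=3840.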

comfy/ldm/lumina/model.py

Lines changed: 19 additions & 11 deletions
@@ -568,7 +568,7 @@ def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
         ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
 
     # def forward(self, x, t, cap_feats, cap_mask):
-    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
         t = 1.0 - timesteps
         cap_feats = context
         cap_mask = attention_mask
@@ -585,16 +585,24 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
 
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D) # todo check if able to batchify w.o. redundant compute
 
+        patches = transformer_options.get("patches", {})
         transformer_options = kwargs.get("transformer_options", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
-        freqs_cis = freqs_cis.to(x.device)
-
-        for layer in self.layers:
-            x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
-
-        x = self.final_layer(x, adaln_input)
-        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
-
-        return -x
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+        freqs_cis = freqs_cis.to(img.device)
+
+        for i, layer in enumerate(self.layers):
+            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+            if "double_block" in patches:
+                for p in patches["double_block"]:
+                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    if "img" in out:
+                        img[:, cap_size[0]:] = out["img"]
+                    if "txt" in out:
+                        img[:, :cap_size[0]] = out["txt"]
+
+        img = self.final_layer(img, adaln_input)
+        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
+
+        return -img
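Each registered "double_block" patch is called once per transformer block with the image tokens, caption tokens, positional frequencies, modulation vector, raw input and block index, and may return updated "img"/"txt" entries that are written back into the sequence. A minimal sketch of a conforming patch (the 0.99 damping is purely illustrative, not something this commit does):

    def demo_double_block_patch(kwargs):
        # kwargs holds "img", "txt", "pe", "vec", "x", "block_index" and
        # "transformer_options", exactly as wired in the loop above.
        out = {}
        if kwargs["block_index"] % 2 == 0:
            out["img"] = kwargs["img"] * 0.99  # illustrative tweak of the image stream
        return out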

comfy_extras/nodes_model_patch.py

Lines changed: 100 additions & 1 deletion
@@ -6,6 +6,7 @@
 import comfy.model_management
 import comfy.ldm.common_dit
 import comfy.latent_formats
+import comfy.ldm.lumina.controlnet
 
 
 class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +190,35 @@ def _process_layer_features(
 
         return embedding
 
+def z_image_convert(sd):
+    replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
+                    ".attention.norm_k.weight": ".attention.k_norm.weight",
+                    ".attention.norm_q.weight": ".attention.q_norm.weight",
+                    ".attention.to_out.0.weight": ".attention.out.weight"
+                    }
+
+    out_sd = {}
+    for k in sorted(sd.keys()):
+        w = sd[k]
+
+        k_out = k
+        if k_out.endswith(".attention.to_k.weight"):
+            cc = [w]
+            continue
+        if k_out.endswith(".attention.to_q.weight"):
+            cc = [w] + cc
+            continue
+        if k_out.endswith(".attention.to_v.weight"):
+            cc = cc + [w]
+            w = torch.cat(cc, dim=0)
+            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
+
+        for r, rr in replace_keys.items():
+            k_out = k_out.replace(r, rr)
+        out_sd[k_out] = w
+
+    return out_sd
+
 class ModelPatchLoader:
     @classmethod
     def INPUT_TYPES(s):
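z_image_convert leans on sorted() visiting keys lexicographically, so within one attention block to_k arrives first, then to_q, then to_v; cc is built as [k], then [q, k]... wait, [q] prepended then [v] appended, giving q/k/v order before the concat. A standalone sketch of why the fused matrix is equivalent to the separate projections (toy dim, illustrative only):

    import torch

    dim = 8                                  # stand-in for the real dim of 3840
    q, k, v = (torch.randn(dim, dim) for _ in range(3))

    # Fuse the three projection weights row-wise, as z_image_convert does.
    qkv = torch.cat([q, k, v], dim=0)        # (3*dim, dim)

    x = torch.randn(2, dim)
    fused = x @ qkv.t()                      # one matmul for all three projections
    q_out, k_out, v_out = fused.chunk(3, dim=1)
    assert torch.allclose(q_out, x @ q.t())  # identical to projecting separately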
@@ -211,6 +241,9 @@ def load_model_patch(self, name):
         elif 'feature_embedder.mid_layer_norm.bias' in sd:
             sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
             model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+        elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
+            sd = z_image_convert(sd)
+            model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
 
         model.load_state_dict(sd)
         model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
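The detection key mirrors the ModuleDict name built in ZImage_Control.__init__, where the embedder is stored under f"{patch_size}-{f_patch_size}"; with patch_size=2 and f_patch_size=1 that is "2-1". A quick sketch of the correspondence:

    patch_size, f_patch_size = 2, 1
    key = f"control_all_x_embedder.{patch_size}-{f_patch_size}.weight"
    assert key == "control_all_x_embedder.2-1.weight"  # the branch's detection key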
@@ -263,6 +296,69 @@ def to(self, device_or_dtype):
     def models(self):
         return [self.model_patch]
 
+class ZImageControlPatch:
+    def __init__(self, model_patch, vae, image, strength):
+        self.model_patch = model_patch
+        self.vae = vae
+        self.image = image
+        self.strength = strength
+        self.encoded_image = self.encode_latent_cond(image)
+        self.encoded_image_size = (image.shape[1], image.shape[2])
+        self.temp_data = None
+
+    def encode_latent_cond(self, image):
+        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
+        return latent_image
+
+    def __call__(self, kwargs):
+        x = kwargs.get("x")
+        img = kwargs.get("img")
+        txt = kwargs.get("txt")
+        pe = kwargs.get("pe")
+        vec = kwargs.get("vec")
+        block_index = kwargs.get("block_index")
+        spacial_compression = self.vae.spacial_compression_encode()
+        if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
+            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+            comfy.model_management.load_models_gpu(loaded_models)
+
+        cnet_index = (block_index // 5)
+        cnet_index_float = (block_index / 5)
+
+        kwargs.pop("img") # we do ops in place
+        kwargs.pop("txt")
+
+        cnet_blocks = self.model_patch.model.n_control_layers
+        if cnet_index_float > (cnet_blocks - 1):
+            self.temp_data = None
+            return kwargs
+
+        if self.temp_data is None or self.temp_data[0] > cnet_index:
+            self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+
+        while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+            next_layer = self.temp_data[0] + 1
+            self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+
+        if cnet_index_float == self.temp_data[0]:
+            img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+            if cnet_blocks == self.temp_data[0] + 1:
+                self.temp_data = None
+
+        return kwargs
+
+    def to(self, device_or_dtype):
+        if isinstance(device_or_dtype, torch.device):
+            self.encoded_image = self.encoded_image.to(device_or_dtype)
+            self.temp_data = None
+        return self
+
+    def models(self):
+        return [self.model_patch]
+
 class QwenImageDiffsynthControlnet:
     @classmethod
     def INPUT_TYPES(s):
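The cnet_index arithmetic above applies one control skip every fifth base block: the float comparison `cnet_index_float == self.temp_data[0]` only passes when block_index is an exact multiple of 5, and the early return makes the patch inert once all control layers are spent. A sketch of the resulting schedule (the 30-block base model is an assumption for illustration; n_control_layers is 6 per ZImage_Control):

    n_control_layers = 6
    for block_index in range(30):          # hypothetical 30-block base model
        cnet_index = block_index // 5
        if block_index / 5 > (n_control_layers - 1):
            break                          # past the last control layer: nothing applied
        if block_index / 5 == cnet_index:  # exact multiples of 5 only
            print(f"base block {block_index:2d} <- skip from control layer {cnet_index}")
    # prints control layers 0..5 firing at base blocks 0, 5, 10, 15, 20, 25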
@@ -289,7 +385,10 @@ def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
             mask = mask.unsqueeze(2)
             mask = 1.0 - mask
 
-        model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
+        if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
+            model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+        else:
+            model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
         return (model_patched,)