diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 204fc048d466..e98c7d6d8a66 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -178,6 +178,15 @@ def process_in(self, latent): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor +class Flux2(LatentFormat): + latent_channels = 128 + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 23150a7125a2..2472ab79c214 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 return embedding class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): + def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None): super().__init__() - self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device) self.silu = nn.SiLU() - self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) @@ -80,14 +80,14 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device) @dataclass @@ -98,11 +98,11 @@ class ModulationOut: class Modulation(nn.Module): - def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): + def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 - self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) + self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device) def forward(self, vec: Tensor) -> tuple: if vec.ndim == 2: @@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor +class SiLUActivation(nn.Module): + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -142,27 +152,44 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + if mlp_silu_act: + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) if self.modulation: self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + if mlp_silu_act: + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.flipped_img_txt = flipped_img_txt def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): @@ -246,6 +273,8 @@ def __init__( mlp_ratio: float = 4.0, qk_scale: float = None, modulation=True, + mlp_silu_act=False, + bias=True, dtype=None, device=None, operations=None @@ -257,17 +286,24 @@ def __init__( self.scale = qk_scale or head_dim**-0.5 self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp_hidden_dim_first = self.mlp_hidden_dim + if mlp_silu_act: + self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) + self.mlp_act = SiLUActivation() + else: + self.mlp_act = nn.GELU(approximate="tanh") + # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.hidden_size = hidden_size self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.mlp_act = nn.GELU(approximate="tanh") if modulation: self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) else: @@ -279,7 +315,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation else: mod = vec - qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) del qkv @@ -298,11 +334,11 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None): super().__init__() self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device)) def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: if vec.ndim == 2: diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b9d36f2024d7..1a24e6d95f5f 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -15,6 +15,7 @@ MLPEmbedder, SingleStreamBlock, timestep_embedding, + Modulation ) @dataclass @@ -33,6 +34,11 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + global_modulation: bool = False + mlp_silu_act: bool = False + ops_bias: bool = True + default_ref_method: str = "offset" + ref_index_scale: float = 1.0 class Flux(nn.Module): @@ -58,13 +64,17 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) - self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) self.double_blocks = nn.ModuleList( [ @@ -73,6 +83,9 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=params.global_modulation is False, + mlp_silu_act=params.mlp_silu_act, + proj_bias=params.ops_bias, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -81,13 +94,30 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) if final_layer: - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + + if params.global_modulation: + self.double_stream_modulation_img = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.double_stream_modulation_txt = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.single_stream_modulation = Modulation( + self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations + ) def forward_orig( self, @@ -103,9 +133,6 @@ def forward_orig( attn_mask: Tensor = None, ) -> Tensor: - if y is None: - y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) - patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: @@ -118,9 +145,17 @@ def forward_orig( if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + txt = self.txt_in(txt) + vec_orig = vec + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) + if "post_input" in patches: for p in patches["post_input"]: out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) @@ -177,6 +212,9 @@ def block_wrap(args): img = torch.cat((txt, img), 1) + if self.params.global_modulation: + vec, _ = self.single_stream_modulation(vec_orig) + for i, block in enumerate(self.single_blocks): if ("single_block", i) in blocks_replace: def block_wrap(args): @@ -207,7 +245,7 @@ def block_wrap(args): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) return img def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): @@ -234,10 +272,10 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={} h_offset += rope_options.get("shift_y", 0.0) w_offset += rope_options.get("shift_x", 0.0) - img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) + img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): @@ -259,10 +297,10 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None h = 0 w = 0 index = 0 - ref_latents_method = kwargs.get("ref_latents_method", "offset") + ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) for ref in ref_latents: if ref_latents_method == "index": - index += 1 + index += self.params.ref_index_scale h_offset = 0 w_offset = 0 elif ref_latents_method == "uxo": @@ -286,7 +324,11 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) + + if len(self.params.axes_dim) == 4: # Flux 2 + txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 611d36a1bf94..4f50810dc7d6 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -9,6 +9,8 @@ from comfy.ldm.util import get_obj_from_str, instantiate_from_config from comfy.ldm.modules.ema import LitEma import comfy.ops +from einops import rearrange +import comfy.model_management class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = False): @@ -179,6 +181,21 @@ def __init__(self, embed_dim: int, **kwargs): self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim + if ddconfig.get("batch_norm_latent", False): + self.bn_eps = 1e-4 + self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"], + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + self.bn.eval() + else: + self.bn = None + + def get_autoencoder_params(self) -> list: params = super().get_autoencoder_params() return params @@ -201,11 +218,36 @@ def encode( z = torch.cat(z, 0) z, reg_log = self.regularization(z) + + if self.bn is not None: + z = rearrange(z, + "... c (i pi) (j pj) -> ... (c pi pj) i j", + pi=self.ps[0], + pj=self.ps[1], + ) + + z = torch.nn.functional.batch_norm(z, + comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device), + comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device), + momentum=self.bn_momentum, + eps=self.bn_eps) + if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.bn is not None: + s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps) + m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + z = z * s + m + z = rearrange( + z, + "... (c pi pj) i j -> ... c (i pi) (j pj)", + pi=self.ps[0], + pj=self.ps[1], + ) + if self.max_batch_size is None: dec = self.post_quant_conv(z) dec = self.decoder(dec, **decoder_kwargs) diff --git a/comfy/model_base.py b/comfy/model_base.py index e14b552c5053..cad79ecbdf4c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -898,12 +898,13 @@ def extra_conds(self, **kwargs): attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: shape = kwargs["noise"].shape - mask_ref_size = kwargs["attention_mask_img_shape"] - # the model will pad to the patch size, and then divide - # essentially dividing and rounding up - (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) - attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) - out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + mask_ref_size = kwargs.get("attention_mask_img_shape", None) + if mask_ref_size is not None: + # the model will pad to the patch size, and then divide + # essentially dividing and rounding up + (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) + attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) guidance = kwargs.get("guidance", 3.5) if guidance is not None: @@ -928,6 +929,16 @@ def extra_conds_shapes(self, **kwargs): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out +class Flux2(Flux): + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + target_text_len = 512 + if cross_attn.shape[1] < target_text_len: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0)) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0131ca25a51e..b2ba1459dd4a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -200,26 +200,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} - dit_config["image_model"] = "flux" + if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "flux2" + dit_config["axes_dim"] = [32, 32, 32, 32] + dit_config["num_heads"] = 48 + dit_config["mlp_ratio"] = 3.0 + dit_config["theta"] = 2000 + dit_config["out_channels"] = 128 + dit_config["global_modulation"] = True + dit_config["vec_in_dim"] = None + dit_config["mlp_silu_act"] = True + dit_config["qkv_bias"] = False + dit_config["ops_bias"] = False + dit_config["default_ref_method"] = "index" + dit_config["ref_index_scale"] = 10.0 + patch_size = 1 + else: + dit_config["image_model"] = "flux" + dit_config["axes_dim"] = [16, 56, 56] + dit_config["num_heads"] = 24 + dit_config["mlp_ratio"] = 4.0 + dit_config["theta"] = 10000 + dit_config["out_channels"] = 16 + dit_config["qkv_bias"] = True + patch_size = 2 + dit_config["in_channels"] = 16 - patch_size = 2 + dit_config["hidden_size"] = 3072 + dit_config["context_in_dim"] = 4096 + dit_config["patch_size"] = patch_size in_key = "{}img_in.weight".format(key_prefix) if in_key in state_dict_keys: - dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) - dit_config["out_channels"] = 16 + w = state_dict[in_key] + dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size) + dit_config["hidden_size"] = w.shape[0] + + txt_in_key = "{}txt_in.weight".format(key_prefix) + if txt_in_key in state_dict_keys: + w = state_dict[txt_in_key] + dit_config["context_in_dim"] = w.shape[1] + dit_config["hidden_size"] = w.shape[0] + vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 - dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] - dit_config["theta"] = 10000 - dit_config["qkv_bias"] = True if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 diff --git a/comfy/sd.py b/comfy/sd.py index b6df0bd61d48..14dd8944c728 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -356,7 +356,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif sd['decoder.conv_in.weight'].shape[1] == 32: + elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -382,6 +382,17 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.upscale_ratio = 4 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'decoder.post_quant_conv.weight' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."}) + + if 'bn.running_mean' in sd: + ddconfig["batch_norm_latent"] = True + self.downscale_ratio *= 2 + self.upscale_ratio *= 2 + self.latent_channels *= 4 + old_memory_used_decode = self.memory_used_decode + self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 + if 'post_quant_conv.weight' in sd: self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) else: @@ -940,6 +951,8 @@ class TEModel(Enum): QWEN25_7B = 11 BYT5_SMALL_GLYPH = 12 GEMMA_3_4B = 13 + MISTRAL3_24B = 14 + MISTRAL3_24B_PRUNED_FLUX2 = 15 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -972,6 +985,13 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: + weight = sd['model.layers.0.post_attention_layernorm.weight'] + if weight.shape[0] == 5120: + if "model.layers.39.post_attention_layernorm.weight" in sd: + return TEModel.MISTRAL3_24B + else: + return TEModel.MISTRAL3_24B_PRUNED_FLUX2 + return TEModel.LLAMA3_8 return None @@ -1086,6 +1106,10 @@ class EmptyClass: else: clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer + elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2: + clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) + clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer + tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2e64b85e88fb..8fe8e63f66e0 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -741,6 +741,37 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) return out +class Flux2(Flux): + unet_config = { + "image_model": "flux2", + } + + sampling_settings = { + "shift": 2.02, + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Flux2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -1422,6 +1453,7 @@ def clip_target(self, state_dict={}): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] + models += [SVD_img2vid] diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index d61ef66689b3..8dbbca16eefa 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -1,10 +1,13 @@ from comfy import sd1_clip import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip +import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast import torch import os +import json +import base64 class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -68,3 +71,105 @@ def __init__(self, device="cpu", dtype=None, model_options={}): model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ + +def load_mistral_tokenizer(data): + if torch.is_tensor(data): + data = data.numpy().tobytes() + + try: + from transformers.integrations.mistral import MistralConverter + except ModuleNotFoundError: + from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter + + mistral_vocab = json.loads(data) + + special_tokens = {} + vocab = {} + + max_vocab = mistral_vocab["config"]["default_vocab_size"] + + for w in mistral_vocab["vocab"]: + r = w["rank"] + if r >= max_vocab: + continue + + vocab[base64.b64decode(w["token_bytes"])] = r + + for w in mistral_vocab["special_tokens"]: + if "token_bytes" in w: + special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"] + else: + special_tokens[w["token_str"]] = w["rank"] + + all_special = [] + for v in special_tokens: + all_special.append(v) + + special_tokens.update(vocab) + vocab = special_tokens + return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False} + +class MistralTokenizerClass: + @staticmethod + def from_pretrained(path, **kwargs): + return LlamaTokenizerFast(**kwargs) + +class Mistral3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.tekken_data = tokenizer_data.get("tekken_model", None) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"tekken_model": self.tekken_data} + +class Flux2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer) + self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]' + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Mistral3_24BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + textmodel_json_config = {} + num_layers = model_options.get("num_layers", None) + if num_layers is not None: + textmodel_json_config["num_hidden_layers"] = num_layers + if num_layers < 40: + textmodel_json_config["final_norm"] = False + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Flux2TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + + out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1) + out = out.movedim(1, 2) + out = out.reshape(out.shape[0], out.shape[1], -1) + return out, pooled, extra + +def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + if pruned: + model_options = model_options.copy() + model_options["num_layers"] = 30 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Flux2TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index feb44bbb0670..749ff581bc01 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -34,6 +34,28 @@ class Llama2Config: rope_scale = None final_norm: bool = True +@dataclass +class Mistral3Small24BConfig: + vocab_size: int = 131072 + hidden_size: int = 5120 + intermediate_size: int = 32768 + num_hidden_layers: int = 40 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 8192 + rms_norm_eps: float = 1e-5 + rope_theta: float = 1000000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_3BConfig: vocab_size: int = 151936 @@ -465,6 +487,15 @@ def __init__(self, config_dict, dtype, device, operations): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Mistral3Small24B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Mistral3Small24BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_3B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index ce1b2e89fa12..d9c4bba81063 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -2,7 +2,10 @@ import comfy.utils from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import comfy.model_management +import torch +import math +import nodes class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -30,6 +33,27 @@ def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput: encode = execute # TODO: remove +class EmptyFlux2LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyFlux2LatentImage", + display_name="Empty Flux 2 Latent", + category="latent", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) class FluxGuidance(io.ComfyNode): @classmethod @@ -154,6 +178,58 @@ def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput: append = execute # TODO: remove +def generalized_time_snr_shift(t, mu: float, sigma: float): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +def get_schedule(num_steps: int, image_seq_len: int) -> list[float]: + mu = compute_empirical_mu(image_seq_len, num_steps) + timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = generalized_time_snr_shift(timesteps, mu, 1.0) + return timesteps + + +class Flux2Scheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Flux2Scheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=4096), + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) + + @classmethod + def execute(cls, steps, width, height) -> io.NodeOutput: + seq_len = (width * height / (16 * 16)) + sigmas = get_schedule(steps, round(seq_len)) + return io.NodeOutput(sigmas) + + class FluxExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -163,6 +239,8 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: FluxDisableGuidance, FluxKontextImageScale, FluxKontextMultiReferenceLatentMethod, + EmptyFlux2LatentImage, + Flux2Scheduler, ] diff --git a/nodes.py b/nodes.py index f023ae3b687c..f4835c02e590 100644 --- a/nodes.py +++ b/nodes.py @@ -929,7 +929,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}),