diff --git a/diffsynth_engine/models/basic/video_sparse_attention.py b/diffsynth_engine/models/basic/video_sparse_attention.py
index 331fca2..f7497c7 100644
--- a/diffsynth_engine/models/basic/video_sparse_attention.py
+++ b/diffsynth_engine/models/basic/video_sparse_attention.py
@@ -3,10 +3,15 @@
 import functools
 
 from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
-from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size
+
+vsa_core = None
 
 if VIDEO_SPARSE_ATTN_AVAILABLE:
-    from vsa import video_sparse_attn as vsa_core
+    try:
+        from vsa import video_sparse_attn as vsa_core
+    except Exception:
+        vsa_core = None
 
 
 VSA_TILE_SIZE = (4, 4, 4)
@@ -171,6 +176,12 @@ def video_sparse_attn(
     variable_block_sizes: torch.LongTensor,
     non_pad_index: torch.LongTensor,
 ):
+    if vsa_core is None:
+        raise RuntimeError(
+            "Video sparse attention (VSA) is not available. "
+            "Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
+        )
+
     q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
     k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
     v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
 ):
     from yunchang.comm.all_to_all import SeqAllToAll4D
 
-    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    ring_world_size = get_sp_ring_world_size()
+    assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
     sp_ulysses_group = get_sp_ulysses_group()
 
     q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
diff --git a/diffsynth_engine/pipelines/wan_video.py b/diffsynth_engine/pipelines/wan_video.py
index a7f0354..116effb 100644
--- a/diffsynth_engine/pipelines/wan_video.py
+++ b/diffsynth_engine/pipelines/wan_video.py
@@ -650,7 +650,7 @@ def has_any_key(*xs):
             dit_type = "wan2.2-i2v-a14b"
         elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
             dit_type = "wan2.2-t2v-a14b"
-        elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
+        elif has_any_key("patch_embedding.weight") and model_state_dict["patch_embedding.weight"].shape[1] == 48:
             dit_type = "wan2.2-ti2v-5b"
         elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
             dit_type = "wan2.1-flf2v-14b"
@@ -680,6 +680,30 @@ def has_any_key(*xs):
         if config.attn_params is None:
             config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
 
+    def update_weights(self, state_dicts: WanStateDicts) -> None:
+        is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
+                                    ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
+        is_dual_model_pipeline = self.dit2 is not None
+
+        if is_dual_model_state_dict != is_dual_model_pipeline:
+            raise ValueError(
+                f"Model structure mismatch: pipeline has {'dual' if is_dual_model_pipeline else 'single'} model "
+                f"but state_dict is for {'dual' if is_dual_model_state_dict else 'single'} model. "
+                f"Cannot update weights between WAN 2.1 (single model) and WAN 2.2 (dual model)."
+            )
+
+        if is_dual_model_state_dict:
+            if "high_noise_model" in state_dicts.model:
+                self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
+            if "low_noise_model" in state_dicts.model:
+                self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
+        else:
+            self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+
+        self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
+        self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
+        self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
+
     def compile(self):
         self.dit.compile_repeated_blocks()
         if self.dit2 is not None:
diff --git a/diffsynth_engine/utils/parallel.py b/diffsynth_engine/utils/parallel.py
index c9fe538..fc8e2c5 100644
--- a/diffsynth_engine/utils/parallel.py
+++ b/diffsynth_engine/utils/parallel.py
@@ -21,117 +21,33 @@
 import diffsynth_engine.models.basic.attention as attention_ops
 from diffsynth_engine.utils.platform import empty_cache
 from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.process_group import (
+    PROCESS_GROUP,
+    get_cfg_group,
+    get_cfg_world_size,
+    get_cfg_rank,
+    get_cfg_ranks,
+    get_sp_group,
+    get_sp_world_size,
+    get_sp_rank,
+    get_sp_ranks,
+    get_sp_ulysses_group,
+    get_sp_ulysses_world_size,
+    get_sp_ulysses_rank,
+    get_sp_ulysses_ranks,
+    get_sp_ring_group,
+    get_sp_ring_world_size,
+    get_sp_ring_rank,
+    get_sp_ring_ranks,
+    get_tp_group,
+    get_tp_world_size,
+    get_tp_rank,
+    get_tp_ranks,
+)
 
 logger = logging.get_logger(__name__)
 
 
-class Singleton:
-    _instance = None
-
-    def __new__(cls, *args, **kwargs):
-        if not cls._instance:
-            cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
-        return cls._instance
-
-
-class ProcessGroupSingleton(Singleton):
-    def __init__(self):
-        self.CFG_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
-        self.TP_GROUP: Optional[dist.ProcessGroup] = None
-
-        self.CFG_RANKS: List[int] = []
-        self.SP_RANKS: List[int] = []
-        self.SP_ULYSSUES_RANKS: List[int] = []
-        self.SP_RING_RANKS: List[int] = []
-        self.TP_RANKS: List[int] = []
-
-
-PROCESS_GROUP = ProcessGroupSingleton()
-
-
-def get_cfg_group():
-    return PROCESS_GROUP.CFG_GROUP
-
-
-def get_cfg_world_size():
-    return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
-
-
-def get_cfg_rank():
-    return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
-
-
-def get_cfg_ranks():
-    return PROCESS_GROUP.CFG_RANKS
-
-
-def get_sp_group():
-    return PROCESS_GROUP.SP_GROUP
-
-
-def get_sp_world_size():
-    return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
-
-
-def get_sp_rank():
-    return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
-
-
-def get_sp_ranks():
-    return PROCESS_GROUP.SP_RANKS
-
-
-def get_sp_ulysses_group():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP
-
-
-def get_sp_ulysses_world_size():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
-
-
-def get_sp_ulysses_rank():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
-
-
-def get_sp_ulysses_ranks():
-    return PROCESS_GROUP.SP_ULYSSUES_RANKS
-
-
-def get_sp_ring_group():
-    return PROCESS_GROUP.SP_RING_GROUP
-
-
-def get_sp_ring_world_size():
-    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
-
-
-def get_sp_ring_rank():
-    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
-
-
-def get_sp_ring_ranks():
-    return PROCESS_GROUP.SP_RING_RANKS
-
-
-def get_tp_group():
-    return PROCESS_GROUP.TP_GROUP
-
-
-def get_tp_world_size():
-    return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
-
-
-def get_tp_rank():
-    return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
-
-
-def get_tp_ranks():
-    return PROCESS_GROUP.TP_RANKS
-
-
 def init_parallel_pgs(
     cfg_degree: int = 1,
     sp_ulysses_degree: int = 1,
diff --git a/diffsynth_engine/utils/process_group.py b/diffsynth_engine/utils/process_group.py
new file mode 100644
index 0000000..366eed8
--- /dev/null
+++ b/diffsynth_engine/utils/process_group.py
@@ -0,0 +1,149 @@
+"""
+Process group management for distributed training.
+
+This module provides singleton-based process group management for distributed training,
+including support for CFG parallelism, sequence parallelism (Ulysses + Ring), and tensor parallelism.
+"""
+
+import torch.distributed as dist
+from typing import Optional, List
+
+
+class Singleton:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
+        return cls._instance
+
+
+class ProcessGroupSingleton(Singleton):
+    def __init__(self):
+        if not hasattr(self, 'initialized'):
+            self.CFG_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
+            self.TP_GROUP: Optional[dist.ProcessGroup] = None
+
+            self.CFG_RANKS: List[int] = []
+            self.SP_RANKS: List[int] = []
+            self.SP_ULYSSUES_RANKS: List[int] = []
+            self.SP_RING_RANKS: List[int] = []
+            self.TP_RANKS: List[int] = []
+
+            self.initialized = True
+
+
+PROCESS_GROUP = ProcessGroupSingleton()
+
+
+# CFG parallel group functions
+def get_cfg_group():
+    return PROCESS_GROUP.CFG_GROUP
+
+
+def get_cfg_world_size():
+    return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
+
+
+def get_cfg_rank():
+    return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
+
+
+def get_cfg_ranks():
+    return PROCESS_GROUP.CFG_RANKS
+
+
+# Sequence parallel group functions
+def get_sp_group():
+    return PROCESS_GROUP.SP_GROUP
+
+
+def get_sp_world_size():
+    return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
+
+
+def get_sp_rank():
+    return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
+
+
+def get_sp_ranks():
+    return PROCESS_GROUP.SP_RANKS
+
+
+# Sequence parallel Ulysses group functions
+def get_sp_ulysses_group():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+def get_sp_ulysses_world_size():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+def get_sp_ulysses_rank():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+def get_sp_ulysses_ranks():
+    return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+# Sequence parallel Ring group functions
+def get_sp_ring_group():
+    return PROCESS_GROUP.SP_RING_GROUP
+
+
+def get_sp_ring_world_size():
+    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+def get_sp_ring_rank():
+    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+def get_sp_ring_ranks():
+    return PROCESS_GROUP.SP_RING_RANKS
+
+
+# Tensor parallel group functions
+def get_tp_group():
+    return PROCESS_GROUP.TP_GROUP
+
+
+def get_tp_world_size():
+    return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
+
+
+def get_tp_rank():
+    return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
+
+
+def get_tp_ranks():
+    return PROCESS_GROUP.TP_RANKS
+
+
+__all__ = [
+    "PROCESS_GROUP",
+    "get_cfg_group",
+    "get_cfg_world_size",
+    "get_cfg_rank",
+    "get_cfg_ranks",
+    "get_sp_group",
+    "get_sp_world_size",
+    "get_sp_rank",
+    "get_sp_ranks",
+    "get_sp_ulysses_group",
+    "get_sp_ulysses_world_size",
+    "get_sp_ulysses_rank",
+    "get_sp_ulysses_ranks",
+    "get_sp_ring_group",
+    "get_sp_ring_world_size",
+    "get_sp_ring_rank",
+    "get_sp_ring_ranks",
+    "get_tp_group",
+    "get_tp_world_size",
+    "get_tp_rank",
+    "get_tp_ranks",
+]
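
Usage sketch (illustrative, not part of the patch): with the getters relocated into diffsynth_engine/utils/process_group.py, callers such as video_sparse_attention.py import them from the new module. Until something (presumably init_parallel_pgs in utils/parallel.py) populates PROCESS_GROUP, every *_GROUP field is still None, so the getters fall back to a world size of 1 and rank 0. The snippet below relies only on names defined in the new file and assumes no process groups have been initialized yet.

    from diffsynth_engine.utils.process_group import (
        PROCESS_GROUP,
        get_sp_ring_world_size,
        get_sp_ulysses_rank,
        get_tp_world_size,
    )

    # No process groups have been created yet, so every group field is None.
    assert PROCESS_GROUP.SP_RING_GROUP is None
    assert get_sp_ring_world_size() == 1  # single-rank fallback used by the VSA ring-degree check
    assert get_sp_ulysses_rank() == 0     # rank defaults to 0 without an Ulysses group
    assert get_tp_world_size() == 1       # tensor parallelism likewise defaults to 1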