Merged
18 changes: 15 additions & 3 deletions diffsynth_engine/models/basic/video_sparse_attention.py
@@ -3,10 +3,15 @@
import functools

from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size


vsa_core = None
if VIDEO_SPARSE_ATTN_AVAILABLE:
from vsa import video_sparse_attn as vsa_core
try:
from vsa import video_sparse_attn as vsa_core
except Exception:
vsa_core = None

VSA_TILE_SIZE = (4, 4, 4)

@@ -171,6 +176,12 @@ def video_sparse_attn(
variable_block_sizes: torch.LongTensor,
non_pad_index: torch.LongTensor,
):
if vsa_core is None:
raise RuntimeError(
"Video sparse attention (VSA) is not available. "
"Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
)

q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
):
from yunchang.comm.all_to_all import SeqAllToAll4D

assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
ring_world_size = get_sp_ring_world_size()
assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
sp_ulysses_group = get_sp_ulysses_group()

q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
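The hunk above makes the `vsa` import optional: the module now imports cleanly without the package, and `video_sparse_attn` raises an actionable RuntimeError only when it is actually called. A minimal sketch of the same lazy-failure pattern, using a hypothetical `optional_dep` package rather than `vsa` itself (not part of this diff):

_backend = None
try:
    # Optional dependency: failure here must not break module import.
    from optional_dep import fast_kernel as _backend
except Exception:
    _backend = None


def run(x):
    # Defer the error to first use, with a message that says how to fix it.
    if _backend is None:
        raise RuntimeError("optional_dep is not installed; install it to enable this code path.")
    return _backend(x)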
26 changes: 25 additions & 1 deletion diffsynth_engine/pipelines/wan_video.py
@@ -650,7 +650,7 @@ def has_any_key(*xs):
dit_type = "wan2.2-i2v-a14b"
elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
dit_type = "wan2.2-t2v-a14b"
elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
elif has_any_key("patch_embedding.weight") and model_state_dict["patch_embedding.weight"].shape[1] == 48:
dit_type = "wan2.2-ti2v-5b"
elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
dit_type = "wan2.1-flf2v-14b"
@@ -680,6 +680,30 @@ def has_any_key(*xs):
if config.attn_params is None:
config.attn_params = VideoSparseAttentionParams(sparsity=0.9)

def update_weights(self, state_dicts: WanStateDicts) -> None:
is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
is_dual_model_pipeline = self.dit2 is not None

if is_dual_model_state_dict != is_dual_model_pipeline:
raise ValueError(
f"Model structure mismatch: pipeline has {'dual' if is_dual_model_pipeline else 'single'} model "
f"but state_dict is for {'dual' if is_dual_model_state_dict else 'single'} model. "
f"Cannot update weights between WAN 2.1 (single model) and WAN 2.2 (dual model)."
)

if is_dual_model_state_dict:
if "high_noise_model" in state_dicts.model:
self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
if "low_noise_model" in state_dicts.model:
self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
else:
self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)

self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)

def compile(self):
self.dit.compile_repeated_blocks()
if self.dit2 is not None:
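The new `update_weights` method hot-swaps component weights in place and refuses to mix a single-model (WAN 2.1) pipeline with a dual-model (WAN 2.2) state dict. A hypothetical usage sketch, assuming `WanStateDicts` accepts the attributes read above (`model`, `t5`, `vae`, `image_encoder`) as keyword arguments and that `pipe` is an already loaded WAN 2.2 dual-model pipeline; file names are placeholders (not part of this diff):

from safetensors.torch import load_file

state_dicts = WanStateDicts(
    model={
        "high_noise_model": load_file("high_noise.safetensors"),
        "low_noise_model": load_file("low_noise.safetensors"),
    },
    t5=load_file("t5.safetensors"),
    vae=load_file("vae.safetensors"),
    image_encoder=load_file("image_encoder.safetensors"),
)
# Raises ValueError if the single/dual structure does not match the pipeline.
pipe.update_weights(state_dicts)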
130 changes: 23 additions & 107 deletions diffsynth_engine/utils/parallel.py
@@ -21,117 +21,33 @@
import diffsynth_engine.models.basic.attention as attention_ops
from diffsynth_engine.utils.platform import empty_cache
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.process_group import (
PROCESS_GROUP,
get_cfg_group,
get_cfg_world_size,
get_cfg_rank,
get_cfg_ranks,
get_sp_group,
get_sp_world_size,
get_sp_rank,
get_sp_ranks,
get_sp_ulysses_group,
get_sp_ulysses_world_size,
get_sp_ulysses_rank,
get_sp_ulysses_ranks,
get_sp_ring_group,
get_sp_ring_world_size,
get_sp_ring_rank,
get_sp_ring_ranks,
get_tp_group,
get_tp_world_size,
get_tp_rank,
get_tp_ranks,
)

logger = logging.get_logger(__name__)


class Singleton:
_instance = None

def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
return cls._instance


class ProcessGroupSingleton(Singleton):
def __init__(self):
self.CFG_GROUP: Optional[dist.ProcessGroup] = None
self.SP_GROUP: Optional[dist.ProcessGroup] = None
self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
self.TP_GROUP: Optional[dist.ProcessGroup] = None

self.CFG_RANKS: List[int] = []
self.SP_RANKS: List[int] = []
self.SP_ULYSSUES_RANKS: List[int] = []
self.SP_RING_RANKS: List[int] = []
self.TP_RANKS: List[int] = []


PROCESS_GROUP = ProcessGroupSingleton()


def get_cfg_group():
return PROCESS_GROUP.CFG_GROUP


def get_cfg_world_size():
return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1


def get_cfg_rank():
return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0


def get_cfg_ranks():
return PROCESS_GROUP.CFG_RANKS


def get_sp_group():
return PROCESS_GROUP.SP_GROUP


def get_sp_world_size():
return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1


def get_sp_rank():
return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0


def get_sp_ranks():
return PROCESS_GROUP.SP_RANKS


def get_sp_ulysses_group():
return PROCESS_GROUP.SP_ULYSSUES_GROUP


def get_sp_ulysses_world_size():
return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1


def get_sp_ulysses_rank():
return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0


def get_sp_ulysses_ranks():
return PROCESS_GROUP.SP_ULYSSUES_RANKS


def get_sp_ring_group():
return PROCESS_GROUP.SP_RING_GROUP


def get_sp_ring_world_size():
return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1


def get_sp_ring_rank():
return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0


def get_sp_ring_ranks():
return PROCESS_GROUP.SP_RING_RANKS


def get_tp_group():
return PROCESS_GROUP.TP_GROUP


def get_tp_world_size():
return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1


def get_tp_rank():
return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0


def get_tp_ranks():
return PROCESS_GROUP.TP_RANKS


def init_parallel_pgs(
cfg_degree: int = 1,
sp_ulysses_degree: int = 1,
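The parallel.py hunk above replaces the in-module singleton and accessor functions with re-exports from the new process_group module, so existing imports from diffsynth_engine.utils.parallel keep working. A quick sketch of the expected equivalence, assuming both modules import cleanly (not part of this diff):

from diffsynth_engine.utils import parallel, process_group

# parallel re-exports the accessors it now imports from process_group,
# so both names should resolve to the same function object.
assert parallel.get_sp_ring_world_size is process_group.get_sp_ring_world_size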
149 changes: 149 additions & 0 deletions diffsynth_engine/utils/process_group.py
@@ -0,0 +1,149 @@
"""
Process group management for distributed training.

This module provides singleton-based process group management for distributed training,
including support for CFG parallelism, sequence parallelism (Ulysses + Ring), and tensor parallelism.
"""

import torch.distributed as dist
from typing import Optional, List


class Singleton:
_instance = None

def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
return cls._instance


class ProcessGroupSingleton(Singleton):
def __init__(self):
if not hasattr(self, 'initialized'):
self.CFG_GROUP: Optional[dist.ProcessGroup] = None
self.SP_GROUP: Optional[dist.ProcessGroup] = None
self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
self.TP_GROUP: Optional[dist.ProcessGroup] = None

self.CFG_RANKS: List[int] = []
self.SP_RANKS: List[int] = []
self.SP_ULYSSUES_RANKS: List[int] = []
self.SP_RING_RANKS: List[int] = []
self.TP_RANKS: List[int] = []

self.initialized = True


PROCESS_GROUP = ProcessGroupSingleton()


# CFG parallel group functions
def get_cfg_group():
return PROCESS_GROUP.CFG_GROUP


def get_cfg_world_size():
return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1


def get_cfg_rank():
return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0


def get_cfg_ranks():
return PROCESS_GROUP.CFG_RANKS


# Sequence parallel group functions
def get_sp_group():
return PROCESS_GROUP.SP_GROUP


def get_sp_world_size():
return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1


def get_sp_rank():
return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0


def get_sp_ranks():
return PROCESS_GROUP.SP_RANKS


# Sequence parallel Ulysses group functions
def get_sp_ulysses_group():
return PROCESS_GROUP.SP_ULYSSUES_GROUP


def get_sp_ulysses_world_size():
return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1


def get_sp_ulysses_rank():
return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0


def get_sp_ulysses_ranks():
return PROCESS_GROUP.SP_ULYSSUES_RANKS


# Sequence parallel Ring group functions
def get_sp_ring_group():
return PROCESS_GROUP.SP_RING_GROUP


def get_sp_ring_world_size():
return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1


def get_sp_ring_rank():
return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0


def get_sp_ring_ranks():
return PROCESS_GROUP.SP_RING_RANKS


# Tensor parallel group functions
def get_tp_group():
return PROCESS_GROUP.TP_GROUP


def get_tp_world_size():
return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1


def get_tp_rank():
return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0


def get_tp_ranks():
return PROCESS_GROUP.TP_RANKS


__all__ = [
"PROCESS_GROUP",
"get_cfg_group",
"get_cfg_world_size",
"get_cfg_rank",
"get_cfg_ranks",
"get_sp_group",
"get_sp_world_size",
"get_sp_rank",
"get_sp_ranks",
"get_sp_ulysses_group",
"get_sp_ulysses_world_size",
"get_sp_ulysses_rank",
"get_sp_ulysses_ranks",
"get_sp_ring_group",
"get_sp_ring_world_size",
"get_sp_ring_rank",
"get_sp_ring_ranks",
"get_tp_group",
"get_tp_world_size",
"get_tp_rank",
"get_tp_ranks",
]
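When no parallel groups have been initialized, every accessor above falls back to a safe single-process default: groups are None, world sizes are 1, ranks are 0, and rank lists are empty. A minimal sketch of that behavior plus the singleton guarantee, assuming the package is importable and no groups have been set up (not part of this diff):

from diffsynth_engine.utils.process_group import (
    PROCESS_GROUP,
    ProcessGroupSingleton,
    get_sp_world_size,
    get_tp_rank,
)

assert get_sp_world_size() == 1                   # no SP group registered yet
assert get_tp_rank() == 0                         # rank defaults to 0 without torch.distributed
assert ProcessGroupSingleton() is PROCESS_GROUP   # repeated construction returns the same instance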