1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -111,6 +111,7 @@ class LatentPreviewMethod(enum.Enum):
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-aiter-attention", action="store_true", help="Use aiter attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
80 changes: 80 additions & 0 deletions comfy/ldm/modules/attention.py
@@ -39,6 +39,15 @@
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)

AITER_ATTENTION_IS_AVAILABLE = False
try:
import aiter
AITER_ATTENTION_IS_AVAILABLE = True
except ImportError:
if model_management.aiter_attention_enabled():
logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter")
Owner:
"pip install aiter" doesn't install the right aiter.

Author:
Thanks, I've changed it. Since AITER doesn't provide a wheel package yet, we just tell users to refer to the AITER repo and hope they can install it smoothly!

exit(-1)
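
For reference, a minimal sketch of what the revised check might look like; the AITER repository URL (https://github.com/ROCm/aiter) is my assumption and is not part of this diff:

AITER_ATTENTION_IS_AVAILABLE = False
try:
    import aiter
    AITER_ATTENTION_IS_AVAILABLE = True
except ImportError:
    if model_management.aiter_attention_enabled():
        # No wheel is published yet, so point users at the AITER repository instead of pip.
        logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\nPlease follow the install instructions in the AITER repository: https://github.com/ROCm/aiter")
        exit(-1)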

REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
@@ -612,18 +621,87 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out


def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
Owner:
why are you doing this? Just put aiter.flash_attn_func in the attention_aiter function directly.

Author:
> why are you doing this? Just put aiter.flash_attn_func in the attention_aiter function directly.

DONE

dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
causal: bool = False, window_size: tuple = (-1, -1),
bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
deterministic: bool = False) -> torch.Tensor:
"""Wrapper for aiter.flash_attn_func to handle its specific parameters"""
return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
causal=causal, window_size=window_size, bias=bias,
alibi_slopes=alibi_slopes, deterministic=deterministic,
return_lse=False, return_attn_probs=False,
cu_seqlens_q=None, cu_seqlens_kv=None)
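
Following the owner's suggestion in the thread above (answered with "DONE"), a minimal sketch of the direct call inside attention_aiter once the wrapper is dropped; the keyword arguments are copied from the wrapper, and whether they all need to be spelled out depends on aiter's defaults, which this diff does not show:

# Sketch: call aiter.flash_attn_func directly where the wrapper is currently invoked.
out = aiter.flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    bias=mask,
    alibi_slopes=None,
    deterministic=False,
    return_lse=False,
    return_attn_probs=False,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
)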

@wrap_attn
def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
# reshape to (batch, seqlen, nheads, headdim) for aiter
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)

if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)

try:
# aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
out = aiter_flash_attn_wrapper(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
bias=mask,
alibi_slopes=None,
deterministic=False,
)
except Exception as e:
logging.warning(f"Aiter Attention failed, using default SDPA: {e}")
# fallback needs (batch, nheads, seqlen, headdim) format
q_sdpa = q.transpose(1, 2)
k_sdpa = k.transpose(1, 2)
v_sdpa = v.transpose(1, 2)
out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2)

if skip_output_reshape:
# output is already in (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
out = out.transpose(1, 2)
else:
# reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
out = out.reshape(b, -1, heads * dim_head)
return out


optimized_attention = attention_basic

if model_management.sage_attention_enabled():
logging.info("Using sage attention")
optimized_attention = attention_sage
elif model_management.aiter_attention_enabled():
logging.info("Using aiter attention")
optimized_attention = attention_aiter
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
@@ -647,6 +725,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if AITER_ATTENTION_IS_AVAILABLE:
register_attention_function("aiter", attention_aiter)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
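
As a usage note on the register_attention_function hook that the new backend plugs into: a hypothetical sketch of registering an additional backend the same way; the import path follows from this file, while the function body is only an illustrative SDPA fallback and not part of this PR:

import torch
from comfy.ldm.modules.attention import register_attention_function

def attention_custom(q, k, v, heads, mask=None, attn_precision=None,
                     skip_reshape=False, skip_output_reshape=False, **kwargs):
    # Illustrative backend: plain SDPA, following the reshape conventions used in this file.
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    if not skip_output_reshape:
        out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return out

register_attention_function("custom", attention_custom)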
3 changes: 3 additions & 0 deletions comfy/model_management.py
@@ -1083,6 +1083,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled():
return args.use_sage_attention

def aiter_attention_enabled():
return args.use_aiter_attention

def flash_attention_enabled():
return args.use_flash_attention
