Add Aiter Attention Backend Support on AMD GPUs #10511
base: master
Changes from 3 commits
```diff
@@ -39,6 +39,15 @@
         logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
         exit(-1)
 
+AITER_ATTENTION_IS_AVAILABLE = False
+try:
+    import aiter
+    AITER_ATTENTION_IS_AVAILABLE = True
+except ImportError:
+    if model_management.aiter_attention_enabled():
+        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter")
+        exit(-1)
+
 REGISTERED_ATTENTION_FUNCTIONS = {}
 def register_attention_function(name: str, func: Callable):
     # avoid replacing existing functions
```
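For context, the guarded-import pattern this hunk follows can be reduced to a few lines. The sketch below is illustrative only — the `backend_requested` helper is a stand-in for `model_management.aiter_attention_enabled()` and is not part of this PR: the availability flag flips only if the import succeeds, and the process aborts only when the user explicitly asked for the backend.

```python
import importlib.util
import logging
import sys

def backend_requested(name: str) -> bool:
    # Illustrative stand-in for model_management.aiter_attention_enabled():
    # treat a command-line flag as the user's explicit request.
    return f"--use-{name}-attention" in sys.argv

# Flip the availability flag only if the package can actually be imported.
AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None

if not AITER_AVAILABLE and backend_requested("aiter"):
    # Hard-fail only when the user explicitly asked for this backend.
    logging.error("aiter attention was requested but the `aiter` package is not installed.")
    sys.exit(-1)
```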
```diff
@@ -612,18 +621,87 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     except Exception as e:
         logging.warning(f"Flash Attention failed, using default SDPA: {e}")
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
     if not skip_output_reshape:
         out = (
             out.transpose(1, 2).reshape(b, -1, heads * dim_head)
         )
     return out
 
 
+def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                             dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
+                             causal: bool = False, window_size: tuple = (-1, -1),
+                             bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
+                             deterministic: bool = False) -> torch.Tensor:
+    """Wrapper for aiter.flash_attn_func to handle its specific parameters"""
+    return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
+                                 causal=causal, window_size=window_size, bias=bias,
+                                 alibi_slopes=alibi_slopes, deterministic=deterministic,
+                                 return_lse=False, return_attn_probs=False,
+                                 cu_seqlens_q=None, cu_seqlens_kv=None)
+
+
+@wrap_attn
+def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+    # reshape to (batch, seqlen, nheads, headdim) for aiter
+    q, k, v = map(
+        lambda t: t.view(b, -1, heads, dim_head),
+        (q, k, v),
+    )
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    try:
+        # aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
+        out = aiter_flash_attn_wrapper(
+            q,
+            k,
+            v,
+            dropout_p=0.0,
+            softmax_scale=None,
+            causal=False,
+            window_size=(-1, -1),
+            bias=mask,
+            alibi_slopes=None,
+            deterministic=False,
+        )
+    except Exception as e:
+        logging.warning(f"Aiter Attention failed, using default SDPA: {e}")
+        # fallback needs (batch, nheads, seqlen, headdim) format
+        q_sdpa = q.transpose(1, 2)
+        k_sdpa = k.transpose(1, 2)
+        v_sdpa = v.transpose(1, 2)
+        out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = out.transpose(1, 2)
+
+    if skip_output_reshape:
+        # output is already in (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
+        out = out.transpose(1, 2)
+    else:
+        # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
+        out = out.reshape(b, -1, heads * dim_head)
+    return out
+
+
 optimized_attention = attention_basic
 
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
+elif model_management.aiter_attention_enabled():
+    logging.info("Using aiter attention")
+    optimized_attention = attention_aiter
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
```
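To make the layout handling above concrete, here is a rough usage sketch of the shape contract `attention_aiter` is written for. The `comfy.ldm.modules.attention` import path is an assumption about where this file lives in the tree, and the snippet presumes an AMD GPU with `aiter` installed (without it, the function's `except` branch falls back to SDPA, so the output shape is the same either way). Treat it as a sketch, not a tested call.

```python
import torch
# Assumed module path for the file modified in this PR; adjust to your checkout.
from comfy.ldm.modules.attention import attention_aiter

batch, seq_len, heads, dim_head = 2, 77, 8, 64

# Default (skip_reshape=False) layout: (batch, seqlen, heads * dim_head).
q = torch.randn(batch, seq_len, heads * dim_head, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attention_aiter(q, k, v, heads=heads)
# With skip_output_reshape=False the result is flattened back to the input layout.
assert out.shape == (batch, seq_len, heads * dim_head)
```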
```diff
@@ -647,6 +725,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 # register core-supported attention functions
 if SAGE_ATTENTION_IS_AVAILABLE:
     register_attention_function("sage", attention_sage)
+if AITER_ATTENTION_IS_AVAILABLE:
+    register_attention_function("aiter", attention_aiter)
 if FLASH_ATTENTION_IS_AVAILABLE:
     register_attention_function("flash", attention_flash)
 if model_management.xformers_enabled():
```
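A quick way to confirm the new backend actually registered (for example, in a REPL inside the ComfyUI environment) is to inspect the registry dict shown in the diff. The module path below is the same assumption as in the earlier sketch.

```python
# Assumed module path; adjust to your checkout.
from comfy.ldm.modules import attention

print(sorted(attention.REGISTERED_ATTENTION_FUNCTIONS))
# On an AMD box with aiter installed, the list should include "aiter"
# alongside the other available backends such as "sage" and "flash".
```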
"pip install aiter" doesn't install the right aiter.
Thanks, I have changed it. Since AITER doesn't provide a wheel package yet, we now point users to the AITER repo for installation instructions and hope they can install it smoothly!