2 parents 8f48ac8 + 3feb474 commit 4fad55b
jax/_src/cudnn/fused_attention_stablehlo.py
@@ -379,8 +379,8 @@ def check_is_flash_attention(
   else:
     # bf16/fp16 attention conditions
     # Check the head dim.
-    is_on_hopper = is_cuda_compute_capability_equal("9.0")
-    H_max = 256 if is_on_hopper else 128
+    is_hopper_or_later = check_compute_capability("9.0")
+    H_max = 256 if is_hopper_or_later else 128
     # check if multi-head latent attention is needed
     is_mla = qH != vH
     if not (qH <= H_max and qH % 8 == 0):
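The rename reflects a change in semantics: instead of allowing the larger head dimension only when the compute capability equals 9.0 (Hopper), the gate now accepts Hopper or any later architecture. A minimal sketch of an at-least-capability check with that behavior is below; the body is an assumption for illustration, not necessarily the JAX implementation, and it assumes the local GPU device exposes a "major.minor" compute_capability string.

import jax

def check_compute_capability(capability: str) -> bool:
  # Hypothetical sketch: return True if the local GPU's compute capability
  # is greater than or equal to `capability` (e.g. "9.0" for Hopper).
  device = jax.local_devices(backend="gpu")[0]
  current = tuple(int(x) for x in device.compute_capability.split("."))
  target = tuple(int(x) for x in capability.split("."))
  return current >= target

# With ">= 9.0" semantics, Hopper and newer GPUs both get the larger head dim:
H_max = 256 if check_compute_capability("9.0") else 128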