Skip to content

Commit 4fad55b

Browse files
Merge pull request #33504 from mingxu1067:mingh/allow_256H_SDPA
PiperOrigin-RevId: 839816495
2 parents 8f48ac8 + 3feb474 commit 4fad55b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

jax/_src/cudnn/fused_attention_stablehlo.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -379,8 +379,8 @@ def check_is_flash_attention(
379 379   else:
380 380   # bf16/fp16 attention conditions
381 381   # Check the head dim.
382     - is_on_hopper = is_cuda_compute_capability_equal("9.0")
383     - H_max = 256 if is_on_hopper else 128
    382 + is_hopper_or_later = check_compute_capability("9.0")
    383 + H_max = 256 if is_hopper_or_later else 128
384 384   # check if multi-head latent attention is needed
385 385   is_mla = qH != vH
386 386   if not (qH <= H_max and qH % 8 == 0):

0 commit comments

Comments
 (0)