Skip to content

Commit 556a01b

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix breakage cause by list being passed to psum
PiperOrigin-RevId: 839400181
1 parent e704639 commit 556a01b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax/_src/lax/parallel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ def psum(x, axis_name, *, axis_index_groups=None):
121121
[20 22 24 26]
122122
[20 22 24 26]]
123123
"""
124-
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
124+
axes = ((axis_name,) if not isinstance(axis_name, (tuple, list)) else
125+
tuple(axis_name))
125126
if not axes:
126127
return x
127128
def bind(leaf):

0 commit comments

Comments
 (0)