Skip to content

Commit ad44116

Browse files
yashk2810Google-ML-Automation
authored andcommitted
[Take 2] Merge all_gather with all_gather_reduced, psum_scatter with unreduced_psum_scatter and psum with unreduced_psum.
Here are the changes: all_gather signature gets a to argument. all_gather(x, axis_name, tiled=True, to=...). The allowed values are varying and reduced. to defaults to varying to preserve the current behavior but you can get AGR by specifying to='reduced' psum_scatter will infer the input state from the type. If the input is unreduced over the axis_name, then we will dispatch to unreduced_psum_scatter_p. If the input is varying, it will dispatch to reduce_scatter_p psum will infer the input state from the type. If the input is unreduced over the axis_name, then we will dispatch to unreduced_psum_p. If the input is varying, it will dispatch to psum_invariant_p Reverts 5b9cfa3 PiperOrigin-RevId: 839418947
1 parent 5b9cfa3 commit ad44116

File tree

3 files changed

+71
-22
lines changed

3 files changed

+71
-22
lines changed

jax/_src/lax/parallel.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,21 @@ def psum(x, axis_name, *, axis_index_groups=None):
121121
[20 22 24 26]
122122
[20 22 24 26]]
123123
"""
124+
axes = ((axis_name,) if not isinstance(axis_name, (tuple, list)) else
125+
tuple(axis_name))
126+
if not axes:
127+
return x
128+
def bind(leaf):
129+
from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.psum')
130+
if from_ == 'unreduced':
131+
if axis_index_groups is not None:
132+
raise NotImplementedError
133+
return unreduced_psum(leaf, axes)
134+
else:
135+
return _psum(leaf, axes, axis_index_groups=axis_index_groups)
136+
return tree_util.tree_map(bind, x)
137+
138+
def _psum(x, axis_name, *, axis_index_groups):
124139
if not isinstance(axis_name, (tuple, list)):
125140
axis_name = (axis_name,)
126141
if not axis_name:
@@ -1611,7 +1626,8 @@ def insert_collective_pvary(axis_name, x):
16111626
x = pvary(x, tuple(n for n in names_union if n not in aval.vma))
16121627
return x
16131628

1614-
def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
1629+
def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False,
1630+
to: str = 'varying'):
16151631
"""Gather values of x across all replicas.
16161632
16171633
If ``x`` is a pytree then the result is equivalent to mapping this function to
@@ -1675,6 +1691,22 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
16751691
[[12 13 14 15]
16761692
[ 4 5 6 7]]]
16771693
"""
1694+
_allowed_ag_to = {'varying', 'reduced'}
1695+
if to not in _allowed_ag_to:
1696+
raise ValueError(
1697+
"Got unexpected `to` value for `jax.lax.all_gather`. Allowed `to`"
1698+
f" values are: {_allowed_ag_to}")
1699+
if to == 'varying':
1700+
return _all_gather(x, axis_name, axis_index_groups=axis_index_groups,
1701+
axis=axis, tiled=tiled)
1702+
else:
1703+
assert to == 'reduced'
1704+
if axis_index_groups is not None:
1705+
raise NotImplementedError
1706+
return all_gather_reduced(x, axis_name, axis=axis, tiled=tiled)
1707+
1708+
1709+
def _all_gather(x, axis_name, *, axis_index_groups, axis, tiled):
16781710
if not isinstance(axis_name, tuple):
16791711
axis_name = (axis_name,)
16801712
if not axis_name:
@@ -2131,6 +2163,22 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
21312163
[12 14]
21322164
[16 18]]
21332165
"""
2166+
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
2167+
if not axes:
2168+
return x
2169+
def bind(leaf):
2170+
from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.psum_scatter')
2171+
if from_ == 'unreduced':
2172+
if axis_index_groups is not None:
2173+
raise NotImplementedError
2174+
return unreduced_psum_scatter(
2175+
leaf, axes, scatter_dimension=scatter_dimension, tiled=tiled)
2176+
else:
2177+
return _psum_scatter(leaf, axes, scatter_dimension=scatter_dimension,
2178+
axis_index_groups=axis_index_groups, tiled=tiled)
2179+
return tree_util.tree_map(bind, x)
2180+
2181+
def _psum_scatter(x, axis_name, *, scatter_dimension, axis_index_groups, tiled):
21342182
if not isinstance(axis_name, tuple):
21352183
axis_name = (axis_name,)
21362184
if not axis_name:
@@ -2744,7 +2792,7 @@ def _reduced_vary_cast_batcher(vals_in, dims_in, *, axes):
27442792

27452793
################################## pcast #############################
27462794

2747-
def _get_from(aval, axes: tuple[AxisName, ...]) -> str:
2795+
def _get_from(aval, axes: tuple[AxisName, ...], name) -> str:
27482796
vma = aval.vma
27492797
unreduced = aval.sharding.spec.unreduced
27502798
reduced = aval.sharding.spec.reduced
@@ -2765,7 +2813,7 @@ def _get_from(aval, axes: tuple[AxisName, ...]) -> str:
27652813

27662814
if len(out) > 1:
27672815
raise ValueError(
2768-
"`jax.lax.pcast` can only accept axis_name which corresponds to one of"
2816+
f"{name} can only accept axis_name which corresponds to one of"
27692817
" varying, unreduced, reduced or invarying state of the input. Got"
27702818
f" input type: {aval}, axes: {axes} and input state: {out}")
27712819
o, = out
@@ -2779,18 +2827,22 @@ def _get_from(aval, axes: tuple[AxisName, ...]) -> str:
27792827
('reduced', 'varying'): core.reduced_vary_cast,
27802828
}
27812829

2782-
_allowed_to = {'unreduced', 'reduced', 'varying'}
2830+
_allowed_pcast_to = {'unreduced', 'reduced', 'varying'}
27832831

27842832
def pcast(x, axis_name, *, to: str):
27852833
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
27862834
if not axis_name:
27872835
return x
27882836

2789-
if to not in _allowed_to:
2837+
if to not in _allowed_pcast_to:
27902838
raise ValueError(
2791-
f'Got unexpected `to` value. Allowed `to` values are: {_allowed_to}')
2792-
from_ = _get_from(core.typeof(x), axes)
2793-
func = _pcast_funcs.get((from_.lower(), to.lower()), None)
2794-
if func is None:
2795-
raise ValueError(f"Unsupported pcast from={from_}, {to=}")
2796-
return func(x, axes)
2839+
"Got unexpected `to` value. Allowed `to` values are:"
2840+
f" {_allowed_pcast_to}")
2841+
2842+
def bind(leaf):
2843+
from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.pcast')
2844+
func = _pcast_funcs.get((from_, to), None)
2845+
if func is None:
2846+
raise ValueError(f"Unsupported pcast from={from_}, {to=}")
2847+
return func(leaf, axes)
2848+
return tree_util.tree_map(bind, x)

jax/lax/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,6 @@
357357
)
358358
from jax._src.lax.parallel import (
359359
all_gather as all_gather,
360-
all_gather_reduced as all_gather_reduced,
361-
unreduced_psum_scatter as unreduced_psum_scatter,
362-
unreduced_psum as unreduced_psum,
363360
pcast as pcast,
364361
all_gather_p as all_gather_p,
365362
all_to_all as all_to_all,

tests/shard_map_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2598,7 +2598,7 @@ def f(x, y):
25982598
def ag(a):
25992599
self.assertEqual(a.aval.vma, {'seq'})
26002600
self.assertEqual(a.aval.sharding.spec.unreduced, frozenset())
2601-
out = lax.all_gather_reduced(a, axis_name='seq', tiled=True)
2601+
out = lax.all_gather(a, axis_name='seq', tiled=True, to='reduced')
26022602
self.assertEqual(out.aval.vma, frozenset())
26032603
self.assertEqual(out.aval.sharding.spec.unreduced, frozenset())
26042604
self.assertEqual(out.aval.sharding.spec.reduced, {'seq'})
@@ -2618,7 +2618,7 @@ def f_bwd(res, g):
26182618
def rs(a):
26192619
self.assertEqual(a.aval.vma, frozenset())
26202620
self.assertEqual(a.aval.sharding.spec.unreduced, {'data', 'seq'})
2621-
out = lax.unreduced_psum_scatter(a, axis_name='seq', tiled=True)
2621+
out = lax.psum_scatter(a, axis_name='seq', tiled=True)
26222622
self.assertEqual(out.aval.vma, {'seq'})
26232623
self.assertEqual(out.aval.sharding.spec.unreduced, {'data'})
26242624
return out
@@ -2628,7 +2628,7 @@ def rs(a):
26282628
def ar(a):
26292629
self.assertEqual(a.aval.vma, {'seq'})
26302630
self.assertEqual(a.aval.sharding.spec.unreduced, {'data'})
2631-
out = lax.unreduced_psum(a, axis_name='data')
2631+
out = lax.psum(a, axis_name='data')
26322632
self.assertEqual(out.aval.vma, {'seq'})
26332633
self.assertEqual(out.aval.sharding.spec.unreduced, frozenset())
26342634
return out
@@ -2684,7 +2684,7 @@ def ag(a):
26842684
self.assertEqual(a.aval.vma, {'seq'})
26852685
self.assertEqual(a.aval.sharding.spec.unreduced, frozenset())
26862686
self.assertEqual(a.aval.sharding.spec.reduced, {'data'})
2687-
out = lax.all_gather_reduced(a, axis_name='seq', tiled=True)
2687+
out = lax.all_gather(a, axis_name='seq', tiled=True, to='reduced')
26882688
self.assertEqual(out.aval.vma, frozenset())
26892689
self.assertEqual(out.aval.sharding.spec.unreduced, frozenset())
26902690
self.assertEqual(out.aval.sharding.spec.reduced, {'seq', 'data'})
@@ -2727,7 +2727,7 @@ def test_unreduced_psum_fwd_preduced_bwd(self, mesh):
27272727
def ar(x):
27282728
self.assertEqual(x.aval.vma, frozenset())
27292729
self.assertEqual(x.aval.sharding.spec.unreduced, {'x'})
2730-
out = jax.lax.unreduced_psum(x, 'x')
2730+
out = jax.lax.psum(x, 'x')
27312731
self.assertEqual(out.aval.vma, frozenset())
27322732
self.assertEqual(out.aval.sharding.spec.unreduced, frozenset())
27332733
return out
@@ -2797,7 +2797,7 @@ def test_all_gather_reduced_fwd_unreduced_psum_scatter_bwd(self, mesh):
27972797
def ag(a):
27982798
self.assertEqual(a.aval.vma, {'seq'})
27992799
self.assertEqual(a.aval.sharding.spec.unreduced, frozenset())
2800-
out = lax.all_gather_reduced(a, axis_name='seq', tiled=True)
2800+
out = lax.all_gather(a, axis_name='seq', tiled=True, to='reduced')
28012801
self.assertEqual(out.aval.vma, frozenset())
28022802
self.assertEqual(out.aval.sharding.spec.unreduced, frozenset())
28032803
self.assertEqual(out.aval.sharding.spec.reduced, {'seq'})
@@ -2834,7 +2834,7 @@ def test_unreduced_psum_scatter_fwd_all_gather_reduced_bwd(self, mesh):
28342834
def rs(a):
28352835
self.assertEqual(a.aval.vma, frozenset())
28362836
self.assertEqual(a.aval.sharding.spec.unreduced, {'x'})
2837-
out = lax.unreduced_psum_scatter(a, axis_name='x', tiled=True)
2837+
out = lax.psum_scatter(a, axis_name='x', tiled=True)
28382838
self.assertEqual(out.aval.vma, {'x'})
28392839
self.assertEqual(out.aval.sharding.spec.unreduced, frozenset())
28402840
return out
@@ -4469,7 +4469,7 @@ def f(x, y):
44694469
return jax.lax.pcast(a, ('x', 'y'), to='reduced')
44704470

44714471
with self.assertRaisesRegex(
4472-
ValueError, "`jax.lax.pcast` can only accept axis_name which"):
4472+
ValueError, "jax.lax.pcast can only accept axis_name which"):
44734473
f(arr1, arr2)
44744474

44754475
@parameterized.named_parameters(

0 commit comments

Comments
 (0)