Commit ad44116
[Take 2] Merge all_gather with all_gather_reduced, psum_scatter with unreduced_psum_scatter and psum with unreduced_psum.
Here are the changes:
all_gather signature gets a to argument. all_gather(x, axis_name, tiled=True, to=...). The allowed values are varying and reduced. to defaults to varying to preserve the current behavior but you can get AGR by specifying to='reduced'
psum_scatter will infer the input state from the type. If the input is unreduced over the axis_name, then we will dispatch to unreduced_psum_scatter_p. If the input is varying, it will dispatch to reduce_scatter_p
psum will infer the input state from the type. If the input is unreduced over the axis_name, then we will dispatch to unreduced_psum_p. If the input is varying, it will dispatch to psum_invariant_p
Reverts 5b9cfa3
PiperOrigin-RevId: 8394189471 parent 5b9cfa3 commit ad44116
File tree
3 files changed
+71
-22
lines changed- jax
- _src/lax
- lax
- tests
3 files changed
+71
-22
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
121 | 121 | | |
122 | 122 | | |
123 | 123 | | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
124 | 139 | | |
125 | 140 | | |
126 | 141 | | |
| |||
1611 | 1626 | | |
1612 | 1627 | | |
1613 | 1628 | | |
1614 | | - | |
| 1629 | + | |
| 1630 | + | |
1615 | 1631 | | |
1616 | 1632 | | |
1617 | 1633 | | |
| |||
1675 | 1691 | | |
1676 | 1692 | | |
1677 | 1693 | | |
| 1694 | + | |
| 1695 | + | |
| 1696 | + | |
| 1697 | + | |
| 1698 | + | |
| 1699 | + | |
| 1700 | + | |
| 1701 | + | |
| 1702 | + | |
| 1703 | + | |
| 1704 | + | |
| 1705 | + | |
| 1706 | + | |
| 1707 | + | |
| 1708 | + | |
| 1709 | + | |
1678 | 1710 | | |
1679 | 1711 | | |
1680 | 1712 | | |
| |||
2131 | 2163 | | |
2132 | 2164 | | |
2133 | 2165 | | |
| 2166 | + | |
| 2167 | + | |
| 2168 | + | |
| 2169 | + | |
| 2170 | + | |
| 2171 | + | |
| 2172 | + | |
| 2173 | + | |
| 2174 | + | |
| 2175 | + | |
| 2176 | + | |
| 2177 | + | |
| 2178 | + | |
| 2179 | + | |
| 2180 | + | |
| 2181 | + | |
2134 | 2182 | | |
2135 | 2183 | | |
2136 | 2184 | | |
| |||
2744 | 2792 | | |
2745 | 2793 | | |
2746 | 2794 | | |
2747 | | - | |
| 2795 | + | |
2748 | 2796 | | |
2749 | 2797 | | |
2750 | 2798 | | |
| |||
2765 | 2813 | | |
2766 | 2814 | | |
2767 | 2815 | | |
2768 | | - | |
| 2816 | + | |
2769 | 2817 | | |
2770 | 2818 | | |
2771 | 2819 | | |
| |||
2779 | 2827 | | |
2780 | 2828 | | |
2781 | 2829 | | |
2782 | | - | |
| 2830 | + | |
2783 | 2831 | | |
2784 | 2832 | | |
2785 | 2833 | | |
2786 | 2834 | | |
2787 | 2835 | | |
2788 | 2836 | | |
2789 | | - | |
| 2837 | + | |
2790 | 2838 | | |
2791 | | - | |
2792 | | - | |
2793 | | - | |
2794 | | - | |
2795 | | - | |
2796 | | - | |
| 2839 | + | |
| 2840 | + | |
| 2841 | + | |
| 2842 | + | |
| 2843 | + | |
| 2844 | + | |
| 2845 | + | |
| 2846 | + | |
| 2847 | + | |
| 2848 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
357 | 357 | | |
358 | 358 | | |
359 | 359 | | |
360 | | - | |
361 | | - | |
362 | | - | |
363 | 360 | | |
364 | 361 | | |
365 | 362 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2598 | 2598 | | |
2599 | 2599 | | |
2600 | 2600 | | |
2601 | | - | |
| 2601 | + | |
2602 | 2602 | | |
2603 | 2603 | | |
2604 | 2604 | | |
| |||
2618 | 2618 | | |
2619 | 2619 | | |
2620 | 2620 | | |
2621 | | - | |
| 2621 | + | |
2622 | 2622 | | |
2623 | 2623 | | |
2624 | 2624 | | |
| |||
2628 | 2628 | | |
2629 | 2629 | | |
2630 | 2630 | | |
2631 | | - | |
| 2631 | + | |
2632 | 2632 | | |
2633 | 2633 | | |
2634 | 2634 | | |
| |||
2684 | 2684 | | |
2685 | 2685 | | |
2686 | 2686 | | |
2687 | | - | |
| 2687 | + | |
2688 | 2688 | | |
2689 | 2689 | | |
2690 | 2690 | | |
| |||
2727 | 2727 | | |
2728 | 2728 | | |
2729 | 2729 | | |
2730 | | - | |
| 2730 | + | |
2731 | 2731 | | |
2732 | 2732 | | |
2733 | 2733 | | |
| |||
2797 | 2797 | | |
2798 | 2798 | | |
2799 | 2799 | | |
2800 | | - | |
| 2800 | + | |
2801 | 2801 | | |
2802 | 2802 | | |
2803 | 2803 | | |
| |||
2834 | 2834 | | |
2835 | 2835 | | |
2836 | 2836 | | |
2837 | | - | |
| 2837 | + | |
2838 | 2838 | | |
2839 | 2839 | | |
2840 | 2840 | | |
| |||
4469 | 4469 | | |
4470 | 4470 | | |
4471 | 4471 | | |
4472 | | - | |
| 4472 | + | |
4473 | 4473 | | |
4474 | 4474 | | |
4475 | 4475 | | |
| |||
0 commit comments