Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax/_src/op_shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_num_ways_dim_sharded(
return [], 1
if hlo_sharding.is_unreduced():
return [], 1
partitions = hlo_sharding.tile_assignment_dimensions()
partitions = hlo_sharding.dimensions()
subgroup_types = hlo_sharding.subgroup_types()

if subgroup_types == [xc.OpSharding.Type.REPLICATED]:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def parse_flatten_op_sharding(
mesh.shape, hlo_sharding.tile_assignment_devices()
)
mesh_axis = iter(mesh_axis_order)
shape = hlo_sharding.tile_assignment_dimensions()
shape = hlo_sharding.dimensions()
partitions = []
for dim_size in shape:
dim_partitions = []
Expand Down
7 changes: 2 additions & 5 deletions jaxlib/xla_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1336,9 +1336,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) {
nb::lock_self())
.def(
"num_dimensions",
[](const xla::HloSharding& self) {
return self.tile_assignment().num_dimensions();
},
[](const xla::HloSharding& self) { return self.num_dimensions(); },
nb::lock_self())
.def("is_tile_assignment_iota",
[](const xla::HloSharding& self) {
Expand All @@ -1347,8 +1345,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) {
.def(
"tile_assignment_dimensions",
[](const xla::HloSharding& self) {
absl::Span<int64_t const> span =
self.tile_assignment().dimensions();
absl::Span<int64_t const> span = self.dimensions();
CHECK(span.data());
return span;
},
Expand Down
2 changes: 1 addition & 1 deletion tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,7 @@ def test_mesh_pspec_sharding_interface(self):
self.assertArraysEqual(device_assignment, list(mesh.devices.flat),
allow_object_dtype=True)
self.assertTrue(hlo_sharding.is_tiled())
self.assertListEqual(hlo_sharding.tile_assignment_dimensions(), [2, 4])
self.assertListEqual(hlo_sharding.dimensions(), [2, 4])
self.assertListEqual(hlo_sharding.tile_assignment_devices(),
[0, 2, 4, 6, 1, 3, 5, 7])

Expand Down
15 changes: 7 additions & 8 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,10 +747,9 @@ def testVMap(self):
self.assertAllClose(z, x[jnp.newaxis] + y)
self.assertAllClose(w, x)
self.assertEqual(
z.sharding._to_xla_hlo_sharding(z.ndim).tile_assignment_dimensions(),
[1, 2])
self.assertEqual(
w.sharding._to_xla_hlo_sharding(w.ndim).tile_assignment_dimensions(), [2])
z.sharding._to_xla_hlo_sharding(z.ndim).dimensions(), [1, 2]
)
self.assertEqual(w.sharding._to_xla_hlo_sharding(w.ndim).dimensions(), [2])

@jtu.with_mesh([('x', 2)])
def testVMapShardingConstraint(self):
Expand All @@ -765,7 +764,7 @@ def testVMapShardingConstraint(self):
constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
op = constraint_eqn.params['sharding']._to_xla_hlo_sharding(x.ndim)
self.assertTrue(op.is_tiled())
self.assertListEqual(op.tile_assignment_dimensions(), [1, 2])
self.assertListEqual(op.dimensions(), [1, 2])
self.assertListEqual(op.tile_assignment_devices(), [0, 1])
self.assertFalse(op_shardings.is_hlo_sharding_replicated(op))

Expand All @@ -785,7 +784,7 @@ def testVMapShardingConstraintWithSpmdAxis(self):
constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
op = constraint_eqn.params['sharding']._to_xla_hlo_sharding(x.ndim)
self.assertTrue(op.is_tiled())
self.assertListEqual(op.tile_assignment_dimensions(), [2, 1])
self.assertListEqual(op.dimensions(), [2, 1])
self.assertListEqual(op.tile_assignment_devices(), [0, 1])
self.assertFalse(op_shardings.is_hlo_sharding_replicated(op))

Expand Down Expand Up @@ -10077,7 +10076,7 @@ def test_op_sharding_equality_and_hash_equality(self):
self.assertEqual(hs3, xc.HloSharding.iota_tile((4, 2)))
self.assertEqual(hs1.num_devices(), 4)
self.assertEqual(hs1.num_dimensions(), 2)
self.assertEqual(hs1.tile_assignment_dimensions(), [2, 2])
self.assertEqual(hs1.dimensions(), [2, 2])
self.assertEqual(hs1.tile_assignment_devices(), [0, 1, 2, 3])
self.assertTrue(hs1.is_tiled())
self.assertFalse(hs1.replicate_on_last_tile_dim())
Expand Down Expand Up @@ -10329,7 +10328,7 @@ def test_hlo_sharding_manual_replicated(self):
self.assertFalse(hs3.is_manual())
self.assertFalse(hs3.is_replicated())
self.assertEqual(hs3.num_dimensions(), 2)
self.assertEqual(hs3.tile_assignment_dimensions(), [3, 3])
self.assertEqual(hs3.dimensions(), [3, 3])
self.assertEqual(hs3.num_devices(), 9)
self.assertEqual(hs3.tile_assignment_devices(), list(range(0, 9)))
self.assertEqual(
Expand Down
Loading