
Commit eddb6ae

V-MoE Authors authored and copybara-github committed
google-internal visibility change.
PiperOrigin-RevId: 581928522
1 parent efcf732 commit eddb6ae

File tree: 1 file changed, +26 -3 lines changed

vmoe/projects/soft_moe/router.py

Lines changed: 26 additions & 3 deletions
@@ -17,10 +17,11 @@
 Results using this algorithm presented in the paper:
 - "From Sparse to Soft Mixture of Experts" (https://arxiv.org/abs/2308.00951).
 """
-from typing import Dict, Optional, Tuple
+from typing import Dict, Mapping, Optional, Tuple
 
 from absl import logging
 import flax.linen as nn
+from flax.linen import partitioning as nn_partitioning
 import jax
 import jax.numpy as jnp
 from vmoe import moe
@@ -51,9 +52,13 @@ class SoftRouter(nn.Module):
   precision: jax.lax.Precision = jax.lax.Precision.DEFAULT
   partition_spec: Optional[jax.sharding.PartitionSpec] = None
   compute_similarity_metrics: bool = True
+  partitioning_rules: Optional[Mapping[str, Tuple]] = None  # pylint: disable=g-bare-generic
 
   @nn.compact
   def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
+    if self.partitioning_rules:
+      inputs = nn_partitioning.with_sharding_constraint(
+          inputs, self.partitioning_rules['inputs'])
     # Normalize inputs to have unit norm.
     dtype = self.dtype or inputs.dtype
     inputs = normalize(inputs.astype(dtype), axis=-1)
@@ -63,6 +68,10 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
       num_slots = moe.compute_capacity(
           group_size, self.num_experts, self.capacity_factor,
           ceil_or_round='round', multiple_of=1)
+      logging.info(
+          'With num_tokens=%d, num_experts=%d, num_slots=%d, '
+          'capacity_factor=%f.', group_size, self.num_experts,
+          num_slots, self.capacity_factor)
     else:
       num_slots = self.num_slots
     actual_capacity_factor = self.num_experts * num_slots / group_size
@@ -71,11 +80,22 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
         '%sWith num_tokens=%d, num_experts=%d and num_slots=%d, the actual '
         'capacity_factor is %f.', pre, group_size, self.num_experts,
         self.num_slots, actual_capacity_factor)
-    mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
+    if self.partitioning_rules:
+      mu = nn_partitioning.param_with_axes(
+          'mu', self.mu_init, (dim, self.num_experts, num_slots),
+          axes=self.partitioning_rules['mu'])
+    else:
+      mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
     mu = normalize(mu.astype(dtype), axis=0)
     self.sow('intermediates', 'mu_unit', mu)
     # Scale inputs/mu before computing the logits.
-    scale = self.param('scale', self.scale_init, ()).astype(dtype)
+    if self.partitioning_rules:
+      scale = nn_partitioning.param_with_axes(
+          'scale', self.scale_init, (),
+          axes=self.partitioning_rules['scale'])
+    else:
+      scale = self.param('scale', self.scale_init, ())
+    scale = scale.astype(dtype)
     if inputs.size < mu.size:
       inputs = inputs * scale
     else:
@@ -89,6 +109,9 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
     # Compute router logits between pairs of items (m) and total slots (n * p),
     # independently on each group (g).
     logits = jnp.einsum('gmd,dnp->gmnp', inputs, mu, precision=self.precision)
+    if self.partitioning_rules:
+      logits = nn_partitioning.with_sharding_constraint(
+          logits, self.partitioning_rules['logits'])
     logits = self.add_noise(logits)
     # Each slot takes a convex combination of the inputs.
     dispatch_weights = jax.nn.softmax(logits, axis=1)
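For readers wiring this up downstream, here is a minimal sketch of how the new partitioning_rules field might be used. Everything outside the diff is an assumption: the logical axis names ('groups', 'tokens', 'dim', 'experts', 'slots'), the mesh axis names ('data', 'expert'), and the constructor arguments shown. The diff itself only requires a mapping with the keys 'inputs', 'mu', 'scale' and 'logits' whose values are tuples understood by flax.linen.partitioning.

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh

from vmoe.projects.soft_moe import router

# Hypothetical logical axis names for each tensor touched by the new code;
# the tuples follow the einsum 'gmd,dnp->gmnp' in the router.
partitioning_rules = {
    'inputs': ('groups', 'tokens', 'dim'),                # 'gmd'
    'mu': ('dim', 'experts', 'slots'),                    # 'dnp'
    'scale': (),                                          # scalar parameter
    'logits': ('groups', 'tokens', 'experts', 'slots'),   # 'gmnp'
}

# Hypothetical rules binding logical names to mesh axes (or to None).
rules = (('groups', 'data'), ('tokens', None), ('dim', None),
         ('experts', 'expert'), ('slots', None))

soft_router = router.SoftRouter(
    num_experts=8, num_slots=2,
    mu_init=nn.initializers.lecun_normal(),
    scale_init=nn.initializers.ones,
    partitioning_rules=partitioning_rules)

x = jnp.zeros((4, 16, 32))  # (groups, tokens, dim)
mesh = Mesh(np.asarray(jax.devices()[:1]).reshape(1, 1), ('data', 'expert'))
with mesh, nn_partitioning.axis_rules(rules):
  variables = soft_router.init(jax.random.PRNGKey(0), x)

Keeping the rules on the module leaves the router agnostic to the physical mesh: the logical-to-physical binding happens once, at the axis_rules call site, rather than inside the model code.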

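The logging.info added in the capacity branch reports the slot count derived from capacity_factor. As a worked example of the arithmetic, assuming moe.compute_capacity with ceil_or_round='round' and multiple_of=1 reduces to rounding num_tokens * capacity_factor / num_experts (the inverse of the actual_capacity_factor formula visible in the context lines):

group_size = 196   # num_tokens per group, e.g. a 14x14 patch grid
num_experts = 128
capacity_factor = 1.0

# Assumed reduction of moe.compute_capacity(group_size, num_experts,
# capacity_factor, ceil_or_round='round', multiple_of=1).
num_slots = round(group_size * capacity_factor / num_experts)   # -> 2
actual_capacity_factor = num_experts * num_slots / group_size   # -> ~1.306

Because of the rounding, the actual capacity factor generally differs from the requested one, which is exactly what the pre-existing log line after the else branch reports.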
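The switch from self.param to nn_partitioning.param_with_axes is what makes the 'mu' and 'scale' rules stick: besides creating the parameter, it records the logical axis names in a separate params_axes collection. A self-contained toy (not vmoe code) illustrating the mechanism, assuming the flax.linen.partitioning API (param_with_axes, get_axis_names):

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen import partitioning as nn_partitioning


class Toy(nn.Module):
  """Toy module (not part of vmoe) with one axis-annotated parameter."""

  @nn.compact
  def __call__(self, x):
    w = nn_partitioning.param_with_axes(
        'w', nn.initializers.lecun_normal(), (x.shape[-1], 4),
        axes=('dim', 'experts'))
    return x @ w


variables = Toy().init(jax.random.PRNGKey(0), jnp.zeros((2, 3)))
# The axis names live next to the parameters, in their own collection:
specs = nn_partitioning.get_axis_names(variables['params_axes'])
# specs['w'] == PartitionSpec('dim', 'experts')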