 Results using this algorithm presented in the paper:
  - "From Sparse to Soft Mixture of Experts" (https://arxiv.org/abs/2308.00951).
 """
-from typing import Dict, Optional, Tuple
+from typing import Dict, Mapping, Optional, Tuple
 
 from absl import logging
 import flax.linen as nn
+from flax.linen import partitioning as nn_partitioning
 import jax
 import jax.numpy as jnp
 from vmoe import moe
@@ -51,9 +52,13 @@ class SoftRouter(nn.Module):
   precision: jax.lax.Precision = jax.lax.Precision.DEFAULT
   partition_spec: Optional[jax.sharding.PartitionSpec] = None
   compute_similarity_metrics: bool = True
+  partitioning_rules: Optional[Mapping[str, Tuple]] = None  # pylint: disable=g-bare-generic
 
   @nn.compact
   def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
+    if self.partitioning_rules:
+      inputs = nn_partitioning.with_sharding_constraint(
+          inputs, self.partitioning_rules['inputs'])
     # Normalize inputs to have unit norm.
     dtype = self.dtype or inputs.dtype
     inputs = normalize(inputs.astype(dtype), axis=-1)
@@ -63,6 +68,10 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
       num_slots = moe.compute_capacity(
           group_size, self.num_experts, self.capacity_factor,
           ceil_or_round='round', multiple_of=1)
+      logging.info(
+          'With num_tokens=%d, num_experts=%d, num_slots=%d, '
+          'capacity_factor=%f.', group_size, self.num_experts,
+          num_slots, self.capacity_factor)
     else:
       num_slots = self.num_slots
       actual_capacity_factor = self.num_experts * num_slots / group_size
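As a quick numeric illustration of the quantities reported by the new log line, using the unchanged formula actual_capacity_factor = num_experts * num_slots / group_size from the hunk above (the concrete values below are illustrative, not taken from this change):

    group_size, num_experts, num_slots = 196, 128, 2   # example values only
    actual_capacity_factor = num_experts * num_slots / group_size
    print(actual_capacity_factor)  # ~1.306: 128 experts x 2 slots = 256 slots for 196 tokens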
@@ -71,11 +80,22 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
           '%sWith num_tokens=%d, num_experts=%d and num_slots=%d, the actual '
           'capacity_factor is %f.', pre, group_size, self.num_experts,
           self.num_slots, actual_capacity_factor)
-    mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
+    if self.partitioning_rules:
+      mu = nn_partitioning.param_with_axes(
+          'mu', self.mu_init, (dim, self.num_experts, num_slots),
+          axes=self.partitioning_rules['mu'])
+    else:
+      mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
     mu = normalize(mu.astype(dtype), axis=0)
     self.sow('intermediates', 'mu_unit', mu)
     # Scale inputs/mu before computing the logits.
-    scale = self.param('scale', self.scale_init, ()).astype(dtype)
+    if self.partitioning_rules:
+      scale = nn_partitioning.param_with_axes(
+          'scale', self.scale_init, (),
+          axes=self.partitioning_rules['scale'])
+    else:
+      scale = self.param('scale', self.scale_init, ())
+    scale = scale.astype(dtype)
     if inputs.size < mu.size:
       inputs = inputs * scale
     else:
@@ -89,6 +109,9 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
     # Compute router logits between pairs of items (m) and total slots (n * p),
     # independently on each group (g).
     logits = jnp.einsum('gmd,dnp->gmnp', inputs, mu, precision=self.precision)
+    if self.partitioning_rules:
+      logits = nn_partitioning.with_sharding_constraint(
+          logits, self.partitioning_rules['logits'])
     logits = self.add_noise(logits)
     # Each slot takes a convex combination of the inputs.
     dispatch_weights = jax.nn.softmax(logits, axis=1)
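For reference, a minimal sketch of how the new partitioning_rules field could be populated. The dictionary keys ('inputs', 'mu', 'scale', 'logits') match the lookups added in this diff; the logical axis names, the example SoftRouter arguments, and the axis_rules mapping below are assumptions for illustration, not part of this change.

    from flax.linen import partitioning as nn_partitioning

    # Hypothetical logical axes for each array the router annotates; the names are assumed.
    partitioning_rules = {
        'inputs': ('batch', 'length', 'embed'),                  # (groups, tokens, dim)
        'mu': ('embed', 'expert', 'expert_slot'),                # (dim, num_experts, num_slots)
        'scale': (),                                             # scalar parameter, no axes
        'logits': ('batch', 'length', 'expert', 'expert_slot'),  # (groups, tokens, experts, slots)
    }

    # The logical names only constrain sharding once they are mapped to mesh axes,
    # e.g. inside an axis_rules context (mesh axis names below are also assumed):
    # with nn_partitioning.axis_rules((('batch', 'data'), ('expert', 'model'))):
    #   router = SoftRouter(num_experts=8, num_slots=1,
    #                       partitioning_rules=partitioning_rules)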