@@ -121,7 +121,6 @@ def unflatten_mesh(
121121 )
122122
123123 batch = self .dp_replicate * self .dp_shard
124- loss = self .dp_replicate * self .dp_shard * self .cp
125124 fsdp = self .dp_shard * self .cp
126125 efsdp = fsdp * self .tp // (self .etp * self .ep )
127126
@@ -145,12 +144,12 @@ def unflatten_mesh(
145144 (self .pp , self .dp_replicate , efsdp , self .ep , self .etp ),
146145 )
147146
148- # We have created all the required 1D meshes. This part is to create the
149- # all the 2D meshes. We pre-created 2D meshes and error out if the users
150- # try to access a 2D mesh that is not pre-created.
151- hsdp_mesh = dense_mesh [ "dp_replicate" , "fsdp" ]
152- ehsdp_mesh = sparse_mesh [ "dp_replicate" , "efsdp" ]
153- ep_etp_mesh = sparse_mesh [ "ep" , "etp" ]
147+ self . _global_meshes = {
148+ "dataloading" : dataloading_mesh ,
149+ "loss" : loss_mesh ,
150+ "dense" : dense_mesh ,
151+ "sparse" : sparse_mesh ,
152+ }
154153
155154 self ._meshes = {
156155 "pp" : dataloading_mesh ["pp" ],
@@ -163,9 +162,6 @@ def unflatten_mesh(
163162 "ep" : sparse_mesh ["ep" ],
164163 "efsdp" : sparse_mesh ["efsdp" ],
165164 "etp" : sparse_mesh ["etp" ],
166- "dp_replicate_fsdp" : hsdp_mesh ,
167- "dp_replicate_efsdp" : ehsdp_mesh ,
168- "ep_etp" : ep_etp_mesh ,
169165 }
170166
171167 # Validate mesh sizes
@@ -191,19 +187,10 @@ def _validate_meshes(self):
191187 "ep" : self .ep ,
192188 "efsdp" : self .dp_shard * self .cp * self .tp // (self .etp * self .ep ),
193189 "etp" : self .etp ,
194- "dp_replicate_fsdp" : (self .dp_replicate , self .dp_shard * self .cp ),
195- "dp_replicate_efsdp" : (
196- self .dp_replicate ,
197- self .dp_shard * self .cp * self .tp // (self .etp * self .ep ),
198- ),
199- "ep_etp" : (self .ep , self .etp ),
200190 }
201191
202192 for mesh_name , expected_size in expected_sizes .items ():
203- if isinstance (expected_size , tuple ):
204- actual_size = self ._meshes [mesh_name ].shape
205- else :
206- actual_size = self ._meshes [mesh_name ].size ()
193+ actual_size = self ._meshes [mesh_name ].size ()
207194 assert actual_size == expected_size , (
208195 f"Mesh '{ mesh_name } ' has unexpected size: "
209196 f"expected { expected_size } , got { actual_size } "
@@ -232,17 +219,24 @@ def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None:
232219 if isinstance (dims , str ):
233220 dims = [dims ]
234221
235- mesh_name = "_" . join ( dims )
236- if mesh_name not in self ._meshes :
237- raise ValueError (
238- f"Invalid mesh dim: '{ mesh_name } '. "
239- f"Valid dimensions are: { list (self ._meshes .keys ())} "
240- )
222+ for mesh_name in dims :
223+ if mesh_name not in self ._meshes :
224+ raise ValueError (
225+ f"Invalid mesh dim: '{ mesh_name } '. "
226+ f"Valid dimensions are: { list (self ._meshes .keys ())} "
227+ )
241228
242229 if any (not self ._mesh_exist (dim , self ._meshes [dim ].size ()) for dim in dims ):
243230 return None
244231
245- return self ._meshes [mesh_name ]
232+ if len (dims ) == 1 :
233+ return self ._meshes [dims [0 ]]
234+ else :
235+ for global_mesh in self ._global_meshes .values ():
236+ if not set (dims ).issubset (set (global_mesh .mesh_dim_names )):
237+ continue
238+ return global_mesh [tuple (dims )]
239+ raise ValueError (f"Invalid mesh name combinations { dims } ." )
246240
247241 def get_all_meshes (self , one_dimensioal_only : bool = True ) -> dict [str , DeviceMesh ]:
248242 if not self ._meshes :
0 commit comments