@@ -73,27 +73,27 @@ def build_mesh(self) -> "ParallelDims":
7373 pp: For PP.
7474 dp_replicate: For DDP or HSDP replicate dimension.
7575 dp_shard_cp: For FSDP or HSDP shard dimension. This includes
76- ``cp`` even if ``cp`` is 1, so we just use the name
77- ``dp_shard_cp``. As a result, we always use the name
78- ``dp_shard_cp`` and ``dp_shard`` is not created as a
79- dimension.
76+ ``cp`` even if ``cp`` is 1. As a result, we always
77+ use the name ``dp_shard_cp``, and ``dp_shard`` is not
78+ created as a dimension.
8079 dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``,
8180 ``dp_shard``, and ``cp`` as all of them are data parallelisms.
82- dp: This is used by data loading. It includes both ``dp_replicate``
83- and ``dp_shard``.
84- The naming can be confusing; ``batch`` could be a better name.
81+ dp: This is used by data loading to decide the global batch size and
82+ which part of data this raunk should read. This dim includes both
83+ ``dp_replicate`` and ``dp_shard``.
84+ The name is confusing; ``batch`` could be a better name.
8585 cp: For CP.
8686 tp: For TP.
8787 ep: For EP.
88- dp_shard_in_ep: For FSDP or HSDP shard dimension in EP region.
88+ dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.
8989
9090 Note: These dimensions won't exist at the same time. If we consider
91- unflatten() operator only, following are all the meshes required
91+ the unflatten() operator only, the following are all the meshes required
9292 assuming all degrees are > 1 except for ``pp``:
9393
94- ["dp", "cp", "tp"]: ``dp`` process group is wasted as dataloader
95- doesn't need it.
96- ["dp_cp", "tp"]: loss computation
94+ ["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader
95+ doesn't need it for communication .
96+ ["dp_cp", "tp"]: Loss computation.
9797 ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation.
9898 ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp.
9999 ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1.
@@ -102,7 +102,7 @@ def build_mesh(self) -> "ParallelDims":
102102 For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"].
103103 So we don't actually need to create ["dp_cp", "tp"].
104104
105- But there are some meses we MUST create if that mesh will be used for a
105+ But there are some meshes we MUST create if that mesh will be used for a
106106 parameter. So Non-EP-region-computation mesh and EP-region-computation mesh
107107 are required.
108108 """
0 commit comments