
Commit e8c4c04

misc
ghstack-source-id: dcf962b Pull-Request: #1892
1 parent 2a2ecdf commit e8c4c04

1 file changed: +13 −13 lines


torchtitan/distributed/parallel_dims.py

Lines changed: 13 additions & 13 deletions
@@ -73,27 +73,27 @@ def build_mesh(self) -> "ParallelDims":
         pp: For PP.
         dp_replicate: For DDP or HSDP replicate dimension.
         dp_shard_cp: For FSDP or HSDP shard dimension. This includes
-            ``cp`` even if ``cp`` is 1, so we just use the name
-            ``dp_shard_cp``. As a result, we always use the name
-            ``dp_shard_cp`` and ``dp_shard`` is not created as a
-            dimension.
+            ``cp`` even if ``cp`` is 1. As a result, we always
+            use the name ``dp_shard_cp``, and ``dp_shard`` is not
+            created as a dimension.
         dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``,
             ``dp_shard``, and ``cp`` as all of them are data parallelisms.
-        dp: This is used by data loading. It includes both ``dp_replicate``
-            and ``dp_shard``.
-            The naming can be confusing; ``batch`` could be a better name.
+        dp: This is used by data loading to decide the global batch size and
+            which part of the data this rank should read. This dim includes
+            both ``dp_replicate`` and ``dp_shard``.
+            The name is confusing; ``batch`` could be a better name.
         cp: For CP.
         tp: For TP.
         ep: For EP.
-        dp_shard_in_ep: For FSDP or HSDP shard dimension in EP region.
+        dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.

     Note: These dimensions won't exist at the same time. If we consider
-    unflatten() operator only, following are all the meshes required
+    the unflatten() operator only, the following are all the meshes required
     assuming all degrees are > 1 except for ``pp``:

-    ["dp", "cp", "tp"]: ``dp`` process group is wasted as dataloader
-        doesn't need it.
-    ["dp_cp", "tp"]: loss computation
+    ["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader
+        doesn't need it for communication.
+    ["dp_cp", "tp"]: Loss computation.
     ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation.
     ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp.
     ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1.
@@ -102,7 +102,7 @@ def build_mesh(self) -> "ParallelDims":
     For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"].
     So we don't actually need to create ["dp_cp", "tp"].

-    But there are some meses we MUST create if that mesh will be used for a
+    But there are some meshes we MUST create if that mesh will be used for a
     parameter. So Non-EP-region-computation mesh and EP-region-computation mesh
     are required.
     """
