Commit 9082c19

[autoparallel] Add experimental config to enable autoparallel_asynctp
stack-info: PR: #1772, branch: IvanKobzarev/stack/2
1 parent db22479 commit 9082c19

File tree: 2 files changed, +28 -2 lines changed


torchtitan/config/job_config.py

Lines changed: 3 additions & 0 deletions
@@ -739,6 +739,9 @@ class Experimental:
 
     enable_simplefsdp_passes: bool = False
 
+    enable_autoparallel_asynctp: bool = False
+
+
 @dataclass
 class Validation:
     enable: bool = False
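
Note (illustrative, not part of the commit): the new flag defaults to False and is consumed below in parallelize_llama.py as job_config.experimental.enable_autoparallel_asynctp. A minimal opt-in sketch, assuming JobConfig is importable from the edited module and that torchtitan's usual section.field mapping applies to the TOML/CLI config surface:

# Hypothetical usage sketch; the exact config entry points may differ.
from torchtitan.config.job_config import JobConfig  # assumed import path

job_config = JobConfig()
print(job_config.experimental.enable_autoparallel_asynctp)  # -> False (default)

# Roughly equivalent to `enable_autoparallel_asynctp = true` under [experimental]
# in the job TOML, or an --experimental.enable_autoparallel_asynctp CLI override.
job_config.experimental.enable_autoparallel_asynctp = True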

torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 25 additions & 2 deletions
@@ -10,7 +10,6 @@
 
 from autoparallel.api import AutoParallel
 
-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
@@ -33,6 +32,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +62,27 @@ def input_fn():
         lambda bucket_idx: 1000 / parallel_dims.tp
     )
 
+    if job_config.experimental.enable_autoparallel_asynctp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        assert "tp" in world_mesh.mesh_dim_names
+        enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name)
+        torch._inductor.config._micro_pipeline_tp = False
+        # Disable inductor AsyncTP passes, in favor of using Autoparallel passes fork.
+        from autoparallel.asynctp import micro_pipeline_tp_pass
+
+        existing_post_grad_custom_post_pass = (
+            torch._inductor.config.post_grad_custom_post_pass
+        )
+
+        def _pass(graph):
+            if existing_post_grad_custom_post_pass is not None:
+                existing_post_grad_custom_post_pass(graph)
+
+            micro_pipeline_tp_pass(graph, None)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
     # bail out
     # model = model_fn()
     # return model
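
Note (illustrative, not part of the commit): the _pass wrapper above chains autoparallel's micro_pipeline_tp_pass behind any previously registered inductor post-grad custom pass rather than overwriting it. A standalone sketch of that chaining pattern, assuming only that torch._inductor.config.post_grad_custom_post_pass accepts a single callable over the post-grad FX graph; my_extra_pass below is a hypothetical stand-in for the autoparallel pass:

import torch
import torch._inductor.config  # ensure the inductor config module is loaded


def my_extra_pass(graph: torch.fx.Graph) -> None:
    # Placeholder: a real pass would rewrite matmul/collective patterns in `graph`.
    pass


# Remember whatever pass (possibly None) was registered before us ...
_previous_pass = torch._inductor.config.post_grad_custom_post_pass


def _chained_pass(graph: torch.fx.Graph) -> None:
    # ... run it first so earlier registrations are not silently dropped,
    if _previous_pass is not None:
        _previous_pass(graph)
    # ... then run the new pass on the same graph.
    my_extra_pass(graph)


torch._inductor.config.post_grad_custom_post_pass = _chained_pass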
@@ -78,6 +99,7 @@ def input_fn():
         world_mesh,
         mp_policy=mp_policy,
         compile=job_config.compile,
+        repeated_subgraphs=True,
     ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
@@ -101,7 +123,8 @@ def input_fn():
     )
     out_sharding = x_sharding
     loss_parallel_enabled = (
-        parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+        parallel_dims.tp_enabled
+        and not job_config.parallelism.disable_loss_parallel
     )
     if loss_parallel_enabled:
         out_sharding = tuple(
