
 from autoparallel.api import AutoParallel

-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard

@@ -33,6 +32,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +62,27 @@ def input_fn():
         lambda bucket_idx: 1000 / parallel_dims.tp
     )

+    if job_config.experimental.enable_autoparallel_asynctp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        assert "tp" in world_mesh.mesh_dim_names
+        enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name)
+        torch._inductor.config._micro_pipeline_tp = False
+        # Disable inductor's AsyncTP passes in favor of AutoParallel's fork of those passes.
+        from autoparallel.asynctp import micro_pipeline_tp_pass
+
+        existing_post_grad_custom_post_pass = (
+            torch._inductor.config.post_grad_custom_post_pass
+        )
+
+        def _pass(graph):
+            if existing_post_grad_custom_post_pass is not None:
+                existing_post_grad_custom_post_pass(graph)
+
+            micro_pipeline_tp_pass(graph, None)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
     # bail out
     # model = model_fn()
     # return model
@@ -78,6 +99,7 @@ def input_fn():
         world_mesh,
         mp_policy=mp_policy,
         compile=job_config.compile,
+        repeated_subgraphs=True,
     ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)

@@ -101,7 +123,8 @@ def input_fn():
         )
         out_sharding = x_sharding
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         if loss_parallel_enabled:
             out_sharding = tuple(
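
Note on the AsyncTP hunk above: it does not overwrite torch._inductor.config.post_grad_custom_post_pass outright; it wraps whatever pass is already registered so both the existing pass and autoparallel's micro_pipeline_tp_pass run. A minimal sketch of that chaining pattern in isolation is below; my_graph_pass is a hypothetical stand-in for the autoparallel pass, while the inductor config knob itself is the real API the diff uses.

import torch
import torch._inductor.config


def my_graph_pass(graph: torch.fx.Graph) -> None:
    # Hypothetical pass; a real pass would rewrite nodes of `graph` in place.
    pass


def chain_post_grad_pass(new_pass):
    # Capture whatever post-grad pass (if any) is currently registered.
    existing = torch._inductor.config.post_grad_custom_post_pass

    def _chained(graph: torch.fx.Graph) -> None:
        if existing is not None:
            existing(graph)  # keep the previously registered pass running
        new_pass(graph)

    torch._inductor.config.post_grad_custom_post_pass = _chained


chain_post_grad_pass(my_graph_pass)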