
 from autoparallel.api import AutoParallel

-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard

@@ -33,6 +32,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +62,57 @@ def input_fn():
         lambda bucket_idx: 1000 / parallel_dims.tp
     )

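+    # Experimental toggles, all off by default: Inductor-level overlap
+    # scheduling, overlap-preserving bucketing of collectives, and async
+    # tensor parallelism (asyncTP).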
+    enable_overlap_scheduling = False
+    enable_overlap_scheduling_bucketing = False
+    if enable_overlap_scheduling_bucketing:
+        assert (
+            enable_overlap_scheduling
+        ), "bucketing cannot be used without overlap scheduling"
+    enable_asynctp = False
+
+    if (
+        enable_overlap_scheduling
+        or enable_overlap_scheduling_bucketing
+        or enable_asynctp
+    ):
+        mesh = world_mesh
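+        # Disable Inductor's own reordering passes and buffer reuse; ordering
+        # of collectives is presumably left to the overlap-scheduling pass
+        # registered below.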
+        torch._inductor.config.reorder_for_peak_memory = False
+        torch._inductor.config.reorder_for_compute_comm_overlap = False
+        torch._inductor.config.allow_buffer_reuse = False
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+            enable_overlap_scheduling_bucketing
+        )
+
+        if enable_asynctp:
+            from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
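+            # AsyncTP relies on symmetric memory, so enable it on the TP and
+            # DP process groups before registering the pass.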
+            enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+            enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
+            torch._inductor.config._micro_pipeline_tp = False
+            # Disable Inductor's AsyncTP passes in favor of AutoParallel's fork of them.
+            # TODO: switch back to the Inductor AsyncTP passes once all additions have landed.
+            from autoparallel.asynctp import micro_pipeline_tp_pass
+
+        existing_post_grad_custom_post_pass = (
+            torch._inductor.config.post_grad_custom_post_pass
+        )
+        from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
+
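+        # Custom post-grad FX pass: first run any previously registered pass,
+        # then overlap scheduling, and finally the asyncTP micro-pipeline pass
+        # (which reuses the scheduler's collective_info).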
+        def _pass(graph):
+            if existing_post_grad_custom_post_pass is not None:
+                existing_post_grad_custom_post_pass(graph)
+
+            collective_info = None
+            if enable_overlap_scheduling:
+                overlap_scheduler = OverlapScheduler(graph.owning_module)
+                overlap_scheduler.run()
+                collective_info = overlap_scheduler.collective_info
+
+            if enable_asynctp:
+                micro_pipeline_tp_pass(graph, collective_info)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
     # bail out
     # model = model_fn()
     # return model
@@ -78,6 +129,7 @@ def input_fn():
         world_mesh,
         mp_policy=mp_policy,
         compile=job_config.compile,
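+        # Assumption: repeated_subgraphs lets AutoParallel exploit repeated
+        # (per-transformer-block) subgraphs instead of re-solving each one.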
+        repeated_subgraphs=True,
     ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)

@@ -101,7 +153,8 @@ def input_fn():
         )
         out_sharding = x_sharding
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         if loss_parallel_enabled:
             out_sharding = tuple(