
Commit 3ab4bf9

[ap] Knobs to enable reorder/bucketing/async_tp passes
stack-info: PR: #1772, branch: IvanKobzarev/stack/2
1 parent db22479 commit 3ab4bf9

1 file changed: +55 −2 lines changed

torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 55 additions & 2 deletions
@@ -10,7 +10,6 @@
 
 from autoparallel.api import AutoParallel
 
-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
@@ -33,6 +32,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +62,57 @@ def input_fn():
         lambda bucket_idx: 1000 / parallel_dims.tp
     )
 
+    enable_overlap_scheduling = False
+    enable_overlap_scheduling_bucketing = False
+    if enable_overlap_scheduling_bucketing:
+        assert (
+            enable_overlap_scheduling
+        ), "bucketing can not be used without overlap scheduling"
+    enable_asynctp = False
+
+    if (
+        enable_overlap_scheduling
+        or enable_overlap_scheduling_bucketing
+        or enable_asynctp
+    ):
+        mesh = world_mesh
+        torch._inductor.config.reorder_for_peak_memory = False
+        torch._inductor.config.reorder_for_compute_comm_overlap = False
+        torch._inductor.config.allow_buffer_reuse = False
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+            enable_overlap_scheduling_bucketing
+        )
+
+        if enable_asynctp:
+            from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+            enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+            enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
+            torch._inductor.config._micro_pipeline_tp = False
+            # Disable inductor AsyncTP passes, in favor of using Autoparallel passes fork.
+            # TODO: Switch to Inductor AsyncTP passes, when all additions landed.
+            from autoparallel.asynctp import micro_pipeline_tp_pass
+
+        existing_post_grad_custom_post_pass = (
+            torch._inductor.config.post_grad_custom_post_pass
+        )
+        from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
+
+        def _pass(graph):
+            if existing_post_grad_custom_post_pass is not None:
+                existing_post_grad_custom_post_pass(graph)
+
+            collective_info = None
+            if enable_overlap_scheduling:
+                overlap_scheduler = OverlapScheduler(graph.owning_module)
+                overlap_scheduler.run()
+                collective_info = overlap_scheduler.collective_info
+
+            if enable_asynctp:
+                micro_pipeline_tp_pass(graph, collective_info)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
     # bail out
     # model = model_fn()
     # return model
@@ -78,6 +129,7 @@ def input_fn():
         world_mesh,
         mp_policy=mp_policy,
         compile=job_config.compile,
+        repeated_subgraphs=True,
     ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
@@ -101,7 +153,8 @@ def input_fn():
         )
         out_sharding = x_sharding
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         if loss_parallel_enabled:
             out_sharding = tuple(
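
Note on the post-grad pass wiring in the hunk above: the diff first captures whatever post_grad_custom_post_pass is already registered on torch._inductor.config, then installs a wrapper that runs that existing pass before the overlap scheduling and (optionally) the autoparallel async-TP pass. Below is a minimal sketch of that chaining pattern only; my_existing_pass and chained_pass are hypothetical placeholders, not code from this commit.

# Sketch only: illustrates the chaining pattern used in the diff, assuming a
# user had already registered their own post-grad pass.
import torch

def my_existing_pass(graph: torch.fx.Graph) -> None:
    # Stand-in for a previously installed custom post-grad pass (hypothetical).
    pass

torch._inductor.config.post_grad_custom_post_pass = my_existing_pass

# Capture the currently registered pass before overwriting the config slot,
# exactly as the diff does with existing_post_grad_custom_post_pass.
existing = torch._inductor.config.post_grad_custom_post_pass

def chained_pass(graph: torch.fx.Graph) -> None:
    if existing is not None:
        existing(graph)  # run whatever was registered before the new passes
    # ... overlap scheduling / async-TP rewrites would run here afterwards ...

torch._inductor.config.post_grad_custom_post_pass = chained_pass

Capturing the old pass in a closure, rather than re-reading the config field inside the wrapper, keeps the wrapper from calling itself once it is assigned to that same field.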

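The async-TP branch also has a runtime prerequisite: symmetric memory must be enabled on every process group whose collectives the pass rewrites, which is why the diff calls enable_symm_mem_for_group for both the "tp" and "dp" mesh dimensions. A standalone sketch of that setup follows, assuming a torchrun launch and a 2-way TP split; the mesh shape and dim names are illustrative, mirroring the diff rather than torchtitan's actual mesh construction.

# Sketch only: prerequisite setup for async-TP, mirroring the calls in the diff.
# Assumes torchrun with a world size divisible by 2; mesh shape is illustrative.
import torch
import torch.distributed as dist
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
world = dist.get_world_size()
mesh = init_device_mesh("cuda", (world // 2, 2), mesh_dim_names=("dp", "tp"))

# Enable symmetric memory on the groups the async-TP pass will rewrite.
enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
enable_symm_mem_for_group(mesh["dp"].get_group().group_name)

# As in the diff, inductor's built-in micro-pipeline TP pass is disabled in
# favor of the autoparallel fork registered via post_grad_custom_post_pass.
torch._inductor.config._micro_pipeline_tp = False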