Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,14 +664,20 @@ class Experimental:
needs to ensure that the path can be imported.
"""

# "none", "all", "only_fsdp"
bucket_all_gathers_fx: str | None = None

# "none", "all"
bucket_reduce_scatters_fx: str | None = None

reorder_for_compute_comm_overlap: bool = False
"""
Whether to enable inductor comm reordering passes
"""

reorder_for_compute_comm_overlap_passes: list[str] = field(
default_factory=lambda: [
"sink_waits",
"sink_waits_iterative",
"reorder_communication_preserving_peak_memory",
]
)
Expand Down
8 changes: 8 additions & 0 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,16 @@ def __init__(self, job_config: JobConfig):
# TODO(whc)
# I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering
torch._inductor.config.force_disable_caches = True
# this is necessary for working with reordering passes. Just leave it set for all the jobs for now.
torch._inductor.config.allow_buffer_reuse = False

# allow configuring inductor comms optimizations from torchtitan commandline
torch._inductor.config.bucket_all_gathers_fx = (
job_config.experimental.bucket_all_gathers_fx
)
torch._inductor.config.bucket_reduce_scatters_fx = (
job_config.experimental.bucket_reduce_scatters_fx
)
torch._inductor.config.reorder_for_compute_comm_overlap = (
job_config.experimental.reorder_for_compute_comm_overlap
)
Expand Down
Loading