Commit 045b159

Undo directory changes made in torchtitan/distributed/utils.py
1 parent b10c0bc commit 045b159

1 file changed: +11, -29 lines

torchtitan/distributed/utils.py

Lines changed: 11 additions & 29 deletions
@@ -106,6 +106,14 @@ def set_determinism(
         # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
+        # Ensure flex_attention is compiled without max-autotune. This is needed to ensure
+        # reproducibility, since the autotune results may not be deterministic.
+        from torch.nn.attention.flex_attention import flex_attention
+
+        from torchtitan.models.attention import FlexAttentionWrapper
+
+        FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)
+
     if not world_mesh:
         if seed is not None:
             torch.manual_seed(seed)
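For context on the added lines: flex_attention relies on torch.compile, and autotuned compilation can pick different kernels from run to run, which changes floating-point reduction order and breaks bitwise reproducibility. A minimal standalone sketch of the same idea (the tensor shapes, CUDA device, and the commented-out autotune mode are illustrative assumptions, not part of this diff):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile with default options: kernel selection is stable across runs.
compiled_flex = torch.compile(flex_attention)
# An autotuned variant may choose different kernels on each run:
# compiled_flex = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")

q = k = v = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
out = compiled_flex(q, k, v)  # same kernel, and hence same numerics, on every run
```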
@@ -185,28 +193,14 @@ def create_context_parallel_ctx(
     )
 
 
-def get_train_context(
-    enable_loss_parallel: bool, enable_compiled_autograd: bool
-) -> Generator[None, None, None]:
+def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]:
     @contextlib.contextmanager
     def context(cp_context: Generator[None, None, None] | None = None):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
 
-            if enable_compiled_autograd:
-                stack.enter_context(
-                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
-                )
-
-            if cp_context is not None:
-                from torch.nn.attention import SDPBackend
-
-                from torchtitan.models.attention import ScaledDotProductAttention
-
-                if SDPBackend.MATH in ScaledDotProductAttention.backends:
-                    ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
-
+            if cp_context:
                 stack.enter_context(cp_context)
 
             yield
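After this change, the train context only handles loss parallel plus an optional context-parallel region. A hedged usage sketch follows; the model, batch, and CP placeholder are made up for illustration, while get_train_context and its inner context(...) manager come from the function above:

```python
import torch
import torch.nn as nn

from torchtitan.distributed.utils import get_train_context

# Placeholder model/batch just to exercise the context manager; not torchtitan's trainer.
model = nn.Linear(16, 16)
batch = torch.randn(4, 16)

# enable_loss_parallel=True would additionally require a tensor-parallel DTensor setup.
train_context = get_train_context(enable_loss_parallel=False)

optional_cp_ctx = None  # e.g. create_context_parallel_ctx(...) when context parallelism is enabled
with train_context(optional_cp_ctx):
    loss = model(batch).float().mean()
    loss.backward()
```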
@@ -274,13 +268,7 @@ def _get_distributed_backend(enable_cpu_backend):
     if comm_config.trace_buf_size > 0:
         # dump on timeout by default if trace buffer is enabled
         _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
-        # ROCm runner doesn't have write permissions for current working directory.
-        # Hence, using HOME directory to save results.
-        if base_folder and os.access(base_folder, os.W_OK):
-            dump_base = base_folder
-        else:
-            dump_base = os.path.expanduser("~")
-        dump_dir = os.path.join(dump_base, comm_config.save_traces_folder)
+        dump_dir = os.path.join(base_folder, comm_config.save_traces_folder)
         prefix = comm_config.save_traces_file_prefix
         os.makedirs(dump_dir, exist_ok=True)
         _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/{prefix}")
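The net effect of this hunk is that flight-recorder traces go under the configured base folder again, with no writability check or HOME fallback. A small sketch of the resulting path handling, using made-up stand-ins for base_folder and the comm_config fields:

```python
import os

# Made-up values standing in for the job's base folder and comm_config fields.
base_folder = "./outputs/debug_run"
save_traces_folder = "comm_traces"
save_traces_file_prefix = "rank_"

dump_dir = os.path.join(base_folder, save_traces_folder)      # ./outputs/debug_run/comm_traces
os.makedirs(dump_dir, exist_ok=True)
trace_file_prefix = f"{dump_dir}/{save_traces_file_prefix}"   # the dump file name is built from this prefix
```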
@@ -455,9 +443,3 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
 
     return total_norm
-
-
-def _round_up(x: int, y: int) -> int:
-    """Round up x to the nearest multiple of y."""
-    x_ceil_div_y = (x + y - 1) // y
-    return x_ceil_div_y * y
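For reference, the removed helper is plain round-up-to-a-multiple arithmetic; a standalone copy with a couple of worked cases:

```python
def _round_up(x: int, y: int) -> int:
    """Round up x to the nearest multiple of y."""
    return (x + y - 1) // y * y

assert _round_up(10, 8) == 16  # (10 + 7) // 8 = 2, then 2 * 8 = 16
assert _round_up(16, 8) == 16  # already a multiple of 8
```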

0 commit comments

Comments
 (0)