@@ -106,6 +106,14 @@ def set_determinism(
     # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
+    # Ensure flex_attention is compiled without max-autotune. This is needed to ensure
+    # reproducibility, since the autotune results may not be deterministic.
+    from torch.nn.attention.flex_attention import flex_attention
+
+    from torchtitan.models.attention import FlexAttentionWrapper
+
+    FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)
+
     if not world_mesh:
         if seed is not None:
             torch.manual_seed(seed)
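The added block above compiles flex_attention with the default torch.compile mode because, as the new comment notes, autotune results may not be deterministic. The sketch below only contrasts the two compile modes; the names `deterministic_flex` and `autotuned_flex` are illustrative and not part of torchtitan.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Default mode: no autotuning sweep, so the chosen kernel (and its numerics) stays the same across runs.
deterministic_flex = torch.compile(flex_attention)

# "max-autotune" benchmarks several candidate kernels and keeps the fastest; the winner can vary
# between runs, which is why the patch avoids it when determinism is requested.
autotuned_flex = torch.compile(flex_attention, mode="max-autotune")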
@@ -185,28 +193,14 @@ def create_context_parallel_ctx(
     )
 
 
-def get_train_context(
-    enable_loss_parallel: bool, enable_compiled_autograd: bool
-) -> Generator[None, None, None]:
+def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]:
     @contextlib.contextmanager
     def context(cp_context: Generator[None, None, None] | None = None):
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
 
-            if enable_compiled_autograd:
-                stack.enter_context(
-                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
-                )
-
-            if cp_context is not None:
-                from torch.nn.attention import SDPBackend
-
-                from torchtitan.models.attention import ScaledDotProductAttention
-
-                if SDPBackend.MATH in ScaledDotProductAttention.backends:
-                    ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
-
+            if cp_context:
                 stack.enter_context(cp_context)
 
             yield
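After this simplification, callers only pass the loss-parallel flag and optionally hand a context-parallel context to the returned manager. A minimal usage sketch, assuming a trainer loop outside this diff (the variable names are illustrative, not the exact torchtitan call site):

train_context = get_train_context(enable_loss_parallel=True)

maybe_cp_context = None  # e.g. produced by create_context_parallel_ctx(...) when CP is enabled
with train_context(maybe_cp_context):
    # forward, loss, and backward run under loss_parallel() and, if provided, the CP context;
    # the enable_compiled_autograd toggle no longer exists on this path.
    ...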
@@ -274,13 +268,7 @@ def _get_distributed_backend(enable_cpu_backend):
     if comm_config.trace_buf_size > 0:
         # dump on timeout by default if trace buffer is enabled
         _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
-        # ROCm runner doesn't have write permissions for current working directory.
-        # Hence, using HOME directory to save results.
-        if base_folder and os.access(base_folder, os.W_OK):
-            dump_base = base_folder
-        else:
-            dump_base = os.path.expanduser("~")
-        dump_dir = os.path.join(dump_base, comm_config.save_traces_folder)
+        dump_dir = os.path.join(base_folder, comm_config.save_traces_folder)
         prefix = comm_config.save_traces_file_prefix
         os.makedirs(dump_dir, exist_ok=True)
         _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/{prefix}")
@@ -455,9 +443,3 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
 
     return total_norm
-
-
-def _round_up(x: int, y: int) -> int:
-    """Round up x to the nearest multiple of y."""
-    x_ceil_div_y = (x + y - 1) // y
-    return x_ceil_div_y * y