 
 
 def to_local_if_dtensor(tensor: Tensor | DTensor) -> Tensor:
-    with torch.no_grad():
-        return tensor.to_local() if isinstance(tensor, DTensor) else tensor
+    return tensor.to_local() if isinstance(tensor, DTensor) else tensor
 
 
 def device_mesh_has_dim(mesh: DeviceMesh, dim_name: str) -> bool:
@@ -90,11 +89,12 @@ def get_grad_norm_fp32(
     data_parallel_group: ProcessGroup,
     model_parallel_group: ProcessGroup,
     norm_type: float = 2.0,
+    offload_params: bool = False,
 ) -> float:
     if isinstance(grads_for_norm, Tensor):
         grads_for_norm = [grads_for_norm]
 
-    grads_for_norm = [to_local_if_dtensor(grad) for grad in grads_for_norm]
+    grads_for_norm = [to_local_if_dtensor(grad).detach() for grad in grads_for_norm]
 
     norm_type = float(norm_type)
     total_norm = 0.0
@@ -105,25 +105,26 @@ def get_grad_norm_fp32(
     device = current_platform.current_device()
 
     if norm_type == torch.inf:
-        norms = [grad.abs().max() for grad in grads_for_norm]
-        total_norm = torch.max(torch.stack(norms)) if norms else 0.0
+        norms = [grad.abs().max().item() for grad in grads_for_norm]
+        total_norm = max(norms) if norms else 0.0
         total_norm_cuda = torch.tensor(
             [float(total_norm)], dtype=torch.float, device=device
         )
         if data_parallel_group:
-            torch.distributed.all_reduce(
+            dist.all_reduce(
                 total_norm_cuda,
-                op=torch.distributed.ReduceOp.MAX,
+                op=dist.ReduceOp.MAX,
                 group=data_parallel_group,
             )
-        torch.distributed.all_reduce(
+        dist.all_reduce(
             total_norm_cuda,
-            op=torch.distributed.ReduceOp.MAX,
+            op=dist.ReduceOp.MAX,
             group=model_parallel_group,
         )
-        total_norm = float(total_norm_cuda[0].item())
+        total_norm = float(total_norm_cuda.item())
     else:
-        if norm_type == 2.0:
+        if norm_type == 2.0 and not offload_params:
+            # Use multi_tensor_applier for better performance when grads are on GPU
             dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device)
             grad_norm, _ = multi_tensor_applier(
                 l2_norm_impl,
@@ -133,20 +134,23 @@ def get_grad_norm_fp32(
             )
             total_norm_cuda = grad_norm ** norm_type
         else:
-            total_norm_cuda = torch.tensor([0.0], dtype=torch.float, device=device)
+            total_norm = 0.0
             for grad in grads_for_norm:
-                grad_norm = torch.norm(grad, norm_type)
-                total_norm_cuda += grad_norm ** norm_type
+                grad_norm = torch.norm(grad, norm_type).item()
+                total_norm += grad_norm ** norm_type
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], dtype=torch.float, device=device
+            )
 
         if data_parallel_group:
-            torch.distributed.all_reduce(
+            dist.all_reduce(
                 total_norm_cuda,
-                op=torch.distributed.ReduceOp.SUM,
+                op=dist.ReduceOp.SUM,
                 group=data_parallel_group,
             )
-        torch.distributed.all_reduce(
+        dist.all_reduce(
             total_norm_cuda,
-            op=torch.distributed.ReduceOp.SUM,
+            op=dist.ReduceOp.SUM,
             group=model_parallel_group,
         )
         total_norm = float(total_norm_cuda.item()) ** (1.0 / norm_type)
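
For reference, the aggregation above relies on the usual p-norm identity: each rank contributes the sum of its local gradient norms raised to the p-th power, the contributions are SUM-reduced across the data-parallel and model-parallel groups, and the p-th root is taken once at the end. A minimal single-process sketch of that identity (plain PyTorch, no process groups; the shard names are made up for illustration):

import torch

p = 2.0
# Two "shards" of gradients, as if they lived on different ranks.
shard_a = [torch.randn(3), torch.randn(4)]
shard_b = [torch.randn(5)]

# Per-shard contribution: sum of ||g||_p ** p over local grads (what each rank computes).
local_pow = [sum(torch.norm(g, p).item() ** p for g in shard) for shard in (shard_a, shard_b)]

# SUM-reduce the contributions, then take the p-th root (what all_reduce + ** (1.0 / p) does).
global_norm = sum(local_pow) ** (1.0 / p)

# Matches the norm computed over the concatenation of all shards.
reference = torch.norm(torch.cat([g.flatten() for g in shard_a + shard_b]), p).item()
assert abs(global_norm - reference) < 1e-3
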
@@ -159,48 +163,68 @@ def clip_grad_by_total_norm_fp32(
     parameters: list[nn.Parameter],
     max_norm: int | float,
     total_norm: float,
-):
+) -> None:
+    clip_coeff = max_norm / (total_norm + 1.0e-6)
+    if clip_coeff >= 1.0:
+        return
+
     # dtype -> grad
     grads = defaultdict(list)
+    cpu_grads = defaultdict(list)
     for param in parameters:
         if param.grad is not None:
-            # For naive FSDP, lm_head has bf16 grad while others have fp32 grad
             grad = to_local_if_dtensor(param.grad).detach()
-            grads[grad.dtype].append(grad)
+            if grad.device.type != "cpu":
+                grads[grad.dtype].append(grad)
+            else:
+                cpu_grads[grad.dtype].append(grad)
 
-    assert len(grads) > 0, len(grads)
-    clip_coeff = max_norm / (total_norm + 1.0e-6)
-    if clip_coeff < 1.0:
-        for dtype, _grads in grads.items():
-            dummy_overflow_buf = torch.tensor(
-                [0], dtype=torch.int, device=current_platform.device_type
+    if len(grads) == 0 and len(cpu_grads) == 0:
+        return
+
+    from .multi_tensor_apply import (
+        local_multi_tensor_applier,
+        local_multi_tensor_scale,
+    )
+
+    # Clip GPU grads
+    for dtype, _grads in grads.items():
+        dummy_overflow_buf = torch.tensor(
+            [0], dtype=torch.int, device=current_platform.device_type
+        )
+        # For naive FSDP, lm_head has bf16 grad while others have fp32 grad
+        if dtype == torch.float32:
+            multi_tensor_applier(
+                multi_tensor_scale_impl,
+                dummy_overflow_buf,
+                [_grads, _grads],
+                clip_coeff,
+            )
+        else:
+            local_multi_tensor_applier(
+                local_multi_tensor_scale,
+                dummy_overflow_buf,
+                [_grads, _grads],
+                clip_coeff,
             )
-            if dtype == torch.float32:
-                multi_tensor_applier(
-                    multi_tensor_scale_impl,
-                    dummy_overflow_buf,
-                    [_grads, _grads],
-                    clip_coeff,
-                )
-            else:
-                from .multi_tensor_apply import (
-                    local_multi_tensor_applier,
-                    local_multi_tensor_scale,
-                )
 
-                local_multi_tensor_applier(
-                    local_multi_tensor_scale,
-                    dummy_overflow_buf,
-                    [_grads, _grads],
-                    clip_coeff,
-                )
+    # Clip CPU grads
+    dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cpu")
+    for _grads in cpu_grads.values():
+        local_multi_tensor_applier(
+            local_multi_tensor_scale,
+            dummy_overflow_buf,
+            [_grads, _grads],
+            clip_coeff,
+        )
 
 
 def fsdp2_clip_grad_norm(
     parameters: list[nn.Parameter],
     nd_device_mesh: DeviceMesh,
     max_norm: float,
     norm_type: float = 2.0,
+    offload_params: bool = False,
 ) -> float:
     if norm_type <= 0 and norm_type != float("inf"):
         raise ValueError(
@@ -215,7 +239,11 @@ def fsdp2_clip_grad_norm(
     grads_for_norm = get_main_grads_for_grad_norm(parameters, tensor_parallel_rank)
 
     grad_norm = get_grad_norm_fp32(
-        grads_for_norm, fsdp_group, tp_group, norm_type=norm_type
+        grads_for_norm,
+        fsdp_group,
+        tp_group,
+        norm_type=norm_type,
+        offload_params=offload_params,
     )
 
     if parameters:
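
A minimal call-site sketch for the updated entry point, assuming an FSDP2-wrapped model and a device mesh carrying the fsdp/tp dimensions; `model`, `mesh`, and `cpu_offload` are placeholders from the caller's training setup, not names defined in this diff:

grad_norm = fsdp2_clip_grad_norm(
    list(model.parameters()),
    nd_device_mesh=mesh,
    max_norm=1.0,
    norm_type=2.0,
    # True when gradients live on CPU (parameter offload); skips the fused GPU l2-norm path.
    offload_params=cpu_offload,
)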