
Commit 043b496

fix: fix CPU offload for FSDP
Updates the FSDP engine to correctly handle DTensors residing on CPU during weight synchronization, ensuring proper materialization to full tensors. Refactors weight gathering logic to batch parameter processing, improving memory efficiency during distributed updates. Modifies gradient norm calculation and clipping functions to support CPU-resident gradients, ensuring operations are applied to the correct device types.
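At the core of the fix is one pattern, shown in `_dtensor_to_full_tensor` in the diff below: a DTensor whose local shard has been offloaded to CPU is re-wrapped from its local tensor before gathering. A condensed sketch of that pattern (the `torch.distributed.tensor` import path assumes a recent PyTorch; older releases expose it as `torch.distributed._tensor`):

import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch


def dtensor_to_full_tensor(dtensor: DTensor) -> torch.Tensor:
    """Gather a (possibly CPU-offloaded) DTensor into one full tensor."""
    local_tensor = dtensor.to_local()
    if local_tensor.device.type != "cpu":
        # Shard already lives on the accelerator; the normal path works.
        return dtensor.full_tensor()
    # Re-wrap the CPU shard with its original mesh and placements, then gather.
    rewrapped = DTensor.from_local(
        local_tensor,
        device_mesh=dtensor.device_mesh,
        placements=dtensor.placements,
    )
    return rewrapped.full_tensor()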
1 parent 601afa7 commit 043b496

2 files changed, +131 −61 lines changed


areal/engine/fsdp_engine.py

Lines changed: 57 additions & 15 deletions
@@ -122,6 +122,7 @@ def __init__(self, config: TrainEngineConfig):
         self.parallel_helper: ParallelHelper
         self.world_mesh: DeviceMesh
 
+        self.fsdp_group: dist.ProcessGroup
         self.dp_group: dist.ProcessGroup
         self.sp_group: dist.ProcessGroup
         self.mp_group: dist.ProcessGroup
@@ -192,6 +193,7 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None
 
         self.world_mesh = self.parallel_helper.world_mesh
 
+        self.fsdp_group = self.world_mesh["dp_sp"].get_group()
         self.dp_group = self.world_mesh["dp"].get_group()
         self.sp_group = self.world_mesh["sp"].get_group()
 
@@ -449,6 +451,44 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
 
         fut.result()
 
+    def _dtensor_to_full_tensor(self, dtensor: DTensor) -> torch.Tensor:
+        """Convert a DTensor to a full tensor, handling CPU offloaded tensors."""
+        local_tensor = dtensor.to_local()
+        if local_tensor.device.type != "cpu":
+            return dtensor.full_tensor()
+
+        device_mesh = dtensor.device_mesh
+        placements = dtensor.placements
+        temp_dtensor = DTensor.from_local(
+            local_tensor,
+            device_mesh=device_mesh,
+            placements=placements,
+        )
+        return temp_dtensor.full_tensor()
+
+    def _materialize_and_update_bucket(
+        self,
+        meta: WeightUpdateMeta,
+        named_params: list[tuple[str, nn.Parameter]],
+    ):
+        """Materialize DTensors to full tensors and broadcast to inference engine."""
+        main_rank: bool = dist.get_rank() == 0
+        named_tensors = []
+
+        for name, param in named_params:
+            if isinstance(param.data, DTensor):
+                tensor = self._dtensor_to_full_tensor(param.data)
+            else:
+                tensor = param.data
+            if tensor.device.type == "cpu":
+                tensor = tensor.to(current_platform.device_type)
+
+            if main_rank:
+                named_tensors.append((name, tensor))
+
+        if named_tensors:
+            self._update_bucket_weights_from_distributed(meta, named_tensors)
+
     @trace_perf("fsdp_engine.update_weights_from_distributed", category="comm")
     def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
         """Broadcast parameters (chunked) from rank 0 (FSDP2 compatible)."""
@@ -459,32 +499,33 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
         dist.barrier(group=self.cpu_group)
 
         weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
+        fsdp_world_size = dist.get_world_size(self.fsdp_group)
 
         buffer_size = 0
-        named_tensors = []
+        named_params_bucket: list[tuple[str, nn.Parameter]] = []
 
         for name, param in self.get_model_name_parameters():
             if isinstance(param.data, DTensor):
-                tensor = param.data.full_tensor()
+                local_tensor = param.data.to_local()
+                tensor_size = local_tensor.numel() * local_tensor.element_size()
+                tensor_size *= fsdp_world_size
             else:
-                tensor = param.data
+                tensor_size = param.data.numel() * param.data.element_size()
 
-            # Ranks other than 0 only help to get the full tensor
-            if dist.get_rank() != 0:
-                continue
-
-            tensor_size = tensor.numel() * tensor.element_size()
-
-            if tensor_size + buffer_size > weight_chunked_mem_size:
-                self._update_bucket_weights_from_distributed(meta, named_tensors)
+            if (
+                tensor_size + buffer_size > weight_chunked_mem_size
+                and named_params_bucket
+            ):
+                self._materialize_and_update_bucket(meta, named_params_bucket)
+                named_params_bucket = []
                 buffer_size = 0
 
-            named_tensors.append((name, tensor))
+            named_params_bucket.append((name, param))
             buffer_size += tensor_size
 
-        # Only rank-0 CAN contain named tensors here
-        if named_tensors:
-            self._update_bucket_weights_from_distributed(meta, named_tensors)
+        # Process remaining parameters
+        if named_params_bucket:
+            self._materialize_and_update_bucket(meta, named_params_bucket)
 
         dist.barrier(group=self.cpu_group)
 
@@ -808,6 +849,7 @@ def train_batch(
                 list(self.model.parameters()),
                 self.world_mesh,
                 max_norm=self.optimizer_config.gradient_clipping,
+                offload_params=self.config.fsdp.offload_params,
             )
 
             if not math.isfinite(grad_norm):
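The weight-gathering refactor above defers `full_tensor()` materialization to `_materialize_and_update_bucket`, sizing buckets from local shards (local bytes × FSDP world size) so that only one bucket of full tensors is resident at a time. A minimal standalone sketch of that flush loop, with hypothetical `estimate_bytes` and `flush` standing in for the shard-size estimate and `_materialize_and_update_bucket`:

from typing import Callable, Iterable

import torch


def bucketed_flush(
    named_params: Iterable[tuple[str, torch.Tensor]],
    budget_bytes: int,
    estimate_bytes: Callable[[torch.Tensor], int],
    flush: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
    bucket: list[tuple[str, torch.Tensor]] = []
    used = 0
    for name, tensor in named_params:
        size = estimate_bytes(tensor)
        # Flush before overflowing the budget, but never flush an empty
        # bucket, so one oversized parameter still goes through on its own.
        if used + size > budget_bytes and bucket:
            flush(bucket)
            bucket, used = [], 0
        bucket.append((name, tensor))
        used += size
    if bucket:
        flush(bucket)  # process the remainder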

areal/utils/fsdp/grad.py

Lines changed: 74 additions & 46 deletions
@@ -50,8 +50,7 @@
 
 
 def to_local_if_dtensor(tensor: Tensor | DTensor) -> Tensor:
-    with torch.no_grad():
-        return tensor.to_local() if isinstance(tensor, DTensor) else tensor
+    return tensor.to_local() if isinstance(tensor, DTensor) else tensor
 
 
 def device_mesh_has_dim(mesh: DeviceMesh, dim_name: str) -> bool:
@@ -90,11 +89,12 @@ def get_grad_norm_fp32(
     data_parallel_group: ProcessGroup,
     model_parallel_group: ProcessGroup,
     norm_type: float = 2.0,
+    offload_params: bool = False,
 ) -> float:
     if isinstance(grads_for_norm, Tensor):
         grads_for_norm = [grads_for_norm]
 
-    grads_for_norm = [to_local_if_dtensor(grad) for grad in grads_for_norm]
+    grads_for_norm = [to_local_if_dtensor(grad).detach() for grad in grads_for_norm]
 
     norm_type = float(norm_type)
     total_norm = 0.0
@@ -105,25 +105,26 @@ def get_grad_norm_fp32(
     device = current_platform.current_device()
 
     if norm_type == torch.inf:
-        norms = [grad.abs().max() for grad in grads_for_norm]
-        total_norm = torch.max(torch.stack(norms)) if norms else 0.0
+        norms = [grad.abs().max().item() for grad in grads_for_norm]
+        total_norm = max(norms) if norms else 0.0
         total_norm_cuda = torch.tensor(
             [float(total_norm)], dtype=torch.float, device=device
         )
         if data_parallel_group:
-            torch.distributed.all_reduce(
+            dist.all_reduce(
                 total_norm_cuda,
-                op=torch.distributed.ReduceOp.MAX,
+                op=dist.ReduceOp.MAX,
                 group=data_parallel_group,
             )
-            torch.distributed.all_reduce(
+            dist.all_reduce(
                 total_norm_cuda,
-                op=torch.distributed.ReduceOp.MAX,
+                op=dist.ReduceOp.MAX,
                 group=model_parallel_group,
             )
-        total_norm = float(total_norm_cuda[0].item())
+        total_norm = float(total_norm_cuda.item())
     else:
-        if norm_type == 2.0:
+        if norm_type == 2.0 and not offload_params:
+            # Use multi_tensor_applier for better performance when grads are on GPU
             dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device)
             grad_norm, _ = multi_tensor_applier(
                 l2_norm_impl,
@@ -133,20 +134,23 @@ def get_grad_norm_fp32(
             )
             total_norm_cuda = grad_norm**norm_type
         else:
-            total_norm_cuda = torch.tensor([0.0], dtype=torch.float, device=device)
+            total_norm = 0.0
             for grad in grads_for_norm:
-                grad_norm = torch.norm(grad, norm_type)
-                total_norm_cuda += grad_norm**norm_type
+                grad_norm = torch.norm(grad, norm_type).item()
+                total_norm += grad_norm**norm_type
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], dtype=torch.float, device=device
+            )
 
         if data_parallel_group:
-            torch.distributed.all_reduce(
+            dist.all_reduce(
                 total_norm_cuda,
-                op=torch.distributed.ReduceOp.SUM,
+                op=dist.ReduceOp.SUM,
                 group=data_parallel_group,
            )
-            torch.distributed.all_reduce(
+            dist.all_reduce(
                 total_norm_cuda,
-                op=torch.distributed.ReduceOp.SUM,
+                op=dist.ReduceOp.SUM,
                 group=model_parallel_group,
             )
         total_norm = float(total_norm_cuda.item()) ** (1.0 / norm_type)
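The fallback path above computes the p-norm per gradient on whatever device the gradients occupy (CPU when `offload_params` is set), accumulates the p-th powers as a Python float, and only moves the scalar to the accelerator for the cross-rank reduction. A condensed sketch of that idea, assuming `device` and `group` match whatever the caller uses for the collective:

import torch
import torch.distributed as dist


def partial_grad_norm(grads, norm_type: float, device, group=None) -> float:
    # Local p-th powers, computed wherever the grads live (CPU or GPU).
    total = 0.0
    for g in grads:
        total += torch.norm(g, norm_type).item() ** norm_type
    # Only the scalar crosses devices for the cross-rank sum.
    buf = torch.tensor([total], dtype=torch.float, device=device)
    if dist.is_initialized():
        dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
    return float(buf.item()) ** (1.0 / norm_type)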
@@ -159,48 +163,68 @@ def clip_grad_by_total_norm_fp32(
     parameters: list[nn.Parameter],
     max_norm: int | float,
     total_norm: float,
-):
+) -> None:
+    clip_coeff = max_norm / (total_norm + 1.0e-6)
+    if clip_coeff >= 1.0:
+        return
+
     # dtype -> grad
     grads = defaultdict(list)
+    cpu_grads = defaultdict(list)
     for param in parameters:
         if param.grad is not None:
-            # For naive FSDP, lm_head has bf16 grad while others have fp32 grad
             grad = to_local_if_dtensor(param.grad).detach()
-            grads[grad.dtype].append(grad)
+            if grad.device.type != "cpu":
+                grads[grad.dtype].append(grad)
+            else:
+                cpu_grads[grad.dtype].append(grad)
 
-    assert len(grads) > 0, len(grads)
-    clip_coeff = max_norm / (total_norm + 1.0e-6)
-    if clip_coeff < 1.0:
-        for dtype, _grads in grads.items():
-            dummy_overflow_buf = torch.tensor(
-                [0], dtype=torch.int, device=current_platform.device_type
+    if len(grads) == 0 and len(cpu_grads) == 0:
+        return
+
+    from .multi_tensor_apply import (
+        local_multi_tensor_applier,
+        local_multi_tensor_scale,
+    )
+
+    # Clip GPU grads
+    for dtype, _grads in grads.items():
+        dummy_overflow_buf = torch.tensor(
+            [0], dtype=torch.int, device=current_platform.device_type
+        )
+        # For naive FSDP, lm_head has bf16 grad while others have fp32 grad
+        if dtype == torch.float32:
+            multi_tensor_applier(
+                multi_tensor_scale_impl,
+                dummy_overflow_buf,
+                [_grads, _grads],
+                clip_coeff,
+            )
+        else:
+            local_multi_tensor_applier(
+                local_multi_tensor_scale,
+                dummy_overflow_buf,
+                [_grads, _grads],
+                clip_coeff,
             )
-            if dtype == torch.float32:
-                multi_tensor_applier(
-                    multi_tensor_scale_impl,
-                    dummy_overflow_buf,
-                    [_grads, _grads],
-                    clip_coeff,
-                )
-            else:
-                from .multi_tensor_apply import (
-                    local_multi_tensor_applier,
-                    local_multi_tensor_scale,
-                )
 
-                local_multi_tensor_applier(
-                    local_multi_tensor_scale,
-                    dummy_overflow_buf,
-                    [_grads, _grads],
-                    clip_coeff,
-                )
+    # Clip CPU grads
+    dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cpu")
+    for _grads in cpu_grads.values():
+        local_multi_tensor_applier(
+            local_multi_tensor_scale,
+            dummy_overflow_buf,
+            [_grads, _grads],
+            clip_coeff,
+        )
 
 
 def fsdp2_clip_grad_norm(
     parameters: list[nn.Parameter],
     nd_device_mesh: DeviceMesh,
     max_norm: float,
     norm_type: float = 2.0,
+    offload_params: bool = False,
 ) -> float:
     if norm_type <= 0 and norm_type != float("inf"):
         raise ValueError(
@@ -215,7 +239,11 @@ def fsdp2_clip_grad_norm(
     grads_for_norm = get_main_grads_for_grad_norm(parameters, tensor_parallel_rank)
 
     grad_norm = get_grad_norm_fp32(
-        grads_for_norm, fsdp_group, tp_group, norm_type=norm_type
+        grads_for_norm,
+        fsdp_group,
+        tp_group,
+        norm_type=norm_type,
+        offload_params=offload_params,
     )
 
     if parameters:
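`clip_grad_by_total_norm_fp32` now returns early when no scaling is needed and splits gradients into GPU and CPU groups, applying the fused multi-tensor scale only where it is supported. A minimal sketch of the same device-aware clipping idea, using `torch._foreach_mul_` as a stand-in for the multi-tensor appliers:

from collections import defaultdict

import torch


def clip_by_total_norm(parameters, max_norm: float, total_norm: float) -> None:
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff >= 1.0:
        return  # already within budget; nothing to scale
    # Group grads by (device, dtype): fused foreach kernels need a
    # homogeneous device, and grouping by dtype mirrors the patch.
    buckets = defaultdict(list)
    for p in parameters:
        if p.grad is not None:
            g = p.grad.detach()
            buckets[(g.device, g.dtype)].append(g)
    for grads in buckets.values():
        torch._foreach_mul_(grads, clip_coeff)  # in-place scale per group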
