
Commit 233e70f

fix: CPU offloading in FSDP grad clipping and weight updates
Updates the gradient clipping implementation to correctly handle parameters offloaded to CPU, bypassing CUDA-specific optimizations when necessary to prevent runtime errors. Refactors the FSDP engine's weight broadcasting logic to properly materialize and batch DTensors in offloaded scenarios. Additionally, introduces a new test suite to verify gradient normalization and clipping behavior across different device configurations.
Parent: 601afa7 · Commit: 233e70f
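The new test file is not rendered in this view. For intuition, a check of the kind the description suggests might compare clipped gradient norms across device placements, along these lines (illustrative sketch only, not the repo's test code):

```python
# Illustrative sketch (not the repo's test suite): clipping should produce
# the same gradient norm whether the gradients live on CPU or on CUDA.
import torch
from torch import nn


def grad_norm_after_clip(device: str, max_norm: float = 1.0) -> float:
    torch.manual_seed(0)
    model = nn.Linear(8, 8)  # initialize on CPU so weights match across runs
    x = torch.randn(4, 8)    # same input data regardless of target device
    model, x = model.to(device), x.to(device)
    model(x).sum().backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    grads = [p.grad.flatten() for p in model.parameters()]
    return torch.cat(grads).norm().item()


cpu_norm = grad_norm_after_clip("cpu")
if torch.cuda.is_available():
    assert abs(cpu_norm - grad_norm_after_clip("cuda")) < 1e-4
```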

File tree: 3 files changed (+626, −46)
areal/engine/fsdp_engine.py

Lines changed: 55 additions & 15 deletions

```diff
@@ -449,6 +449,44 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
 
         fut.result()
 
+    def _dtensor_to_full_tensor(self, dtensor: DTensor) -> torch.Tensor:
+        """Convert a DTensor to a full tensor, handling CPU offloaded tensors."""
+        local_tensor = dtensor.to_local()
+        if local_tensor.device.type != "cpu":
+            return dtensor.full_tensor()
+
+        device_mesh = dtensor.device_mesh
+        placements = dtensor.placements
+        temp_dtensor = DTensor.from_local(
+            local_tensor,
+            device_mesh=device_mesh,
+            placements=placements,
+        )
+        return temp_dtensor.full_tensor()
+
+    def _materialize_and_update_bucket(
+        self,
+        meta: WeightUpdateMeta,
+        named_params: list[tuple[str, nn.Parameter]],
+    ):
+        """Materialize DTensors to full tensors and broadcast to inference engine."""
+        main_rank: bool = dist.get_rank() == 0
+        named_tensors = []
+
+        for name, param in named_params:
+            if isinstance(param.data, DTensor):
+                tensor = self._dtensor_to_full_tensor(param.data)
+            else:
+                tensor = param.data
+            if tensor.device.type == "cpu":
+                tensor = tensor.to(current_platform.device_type)
+
+            if main_rank:
+                named_tensors.append((name, tensor))
+
+        if named_tensors:
+            self._update_bucket_weights_from_distributed(meta, named_tensors)
+
     @trace_perf("fsdp_engine.update_weights_from_distributed", category="comm")
     def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
         """Broadcast parameters (chunked) from rank 0 (FSDP2 compatible)."""
```
```diff
@@ -459,32 +497,33 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
         dist.barrier(group=self.cpu_group)
 
         weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
+        fsdp_world_size = dist.get_world_size(self.world_mesh["dp_sp"].get_group())
 
         buffer_size = 0
-        named_tensors = []
+        named_params_bucket: list[tuple[str, nn.Parameter]] = []
 
         for name, param in self.get_model_name_parameters():
             if isinstance(param.data, DTensor):
-                tensor = param.data.full_tensor()
+                local_tensor = param.data.to_local()
+                tensor_size = local_tensor.numel() * local_tensor.element_size()
+                tensor_size *= fsdp_world_size
             else:
-                tensor = param.data
+                tensor_size = param.data.numel() * param.data.element_size()
 
-            # Ranks other than 0 only help to get the full tensor
-            if dist.get_rank() != 0:
-                continue
-
-            tensor_size = tensor.numel() * tensor.element_size()
-
-            if tensor_size + buffer_size > weight_chunked_mem_size:
-                self._update_bucket_weights_from_distributed(meta, named_tensors)
+            if (
+                tensor_size + buffer_size > weight_chunked_mem_size
+                and named_params_bucket
+            ):
+                self._materialize_and_update_bucket(meta, named_params_bucket)
+                named_params_bucket = []
                 buffer_size = 0
 
-            named_tensors.append((name, tensor))
+            named_params_bucket.append((name, param))
             buffer_size += tensor_size
 
-        # Only rank-0 CAN contain named tensors here
-        if named_tensors:
-            self._update_bucket_weights_from_distributed(meta, named_tensors)
+        # Process remaining parameters
+        if named_params_bucket:
+            self._materialize_and_update_bucket(meta, named_params_bucket)
 
         dist.barrier(group=self.cpu_group)
 
```
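Note the new size check is metadata-only: for a DTensor, the full tensor's footprint is estimated as the local shard's byte size times the FSDP world size, so nothing is materialized until a bucket is flushed, and the `and named_params_bucket` guard lets a single oversized parameter still ship in a bucket of its own. The general pattern, stripped of the engine specifics (hypothetical standalone names):

```python
# Sketch of deferred bucketing: accumulate (name, param) pairs by estimated
# full-tensor size, materializing only when a bucket is flushed. Illustrative.
from typing import Callable


def bucket_params(
    named_params,              # iterable of (name, param-like) pairs
    estimate_bytes: Callable,  # metadata-only size estimate per param
    flush: Callable,           # materializes and ships one bucket
    max_bucket_bytes: int,
):
    bucket, bucket_bytes = [], 0
    for name, param in named_params:
        size = estimate_bytes(param)
        # Flush before overflowing, but never flush an empty bucket: an
        # oversized parameter must still travel in its own bucket.
        if bucket and bucket_bytes + size > max_bucket_bytes:
            flush(bucket)
            bucket, bucket_bytes = [], 0
        bucket.append((name, param))
        bucket_bytes += size
    if bucket:
        flush(bucket)
```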

```diff
@@ -808,6 +847,7 @@ def train_batch(
             list(self.model.parameters()),
             self.world_mesh,
             max_norm=self.optimizer_config.gradient_clipping,
+            offload_params=self.config.fsdp.offload_params,
         )
 
         if not math.isfinite(grad_norm):
```
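The diff only shows the call site; judging by the commit message, `offload_params` tells the clipping utility to skip CUDA-only fast paths when gradients live on CPU. One plausible shape for such a switch, using only the stock PyTorch API (hypothetical helper, not the repo's actual implementation):

```python
import torch
from torch import nn


def clip_grad_norm_offload_aware(
    parameters, max_norm: float, offload_params: bool = False
) -> torch.Tensor:
    # The fused multi-tensor (`foreach`) kernels assume all grads share one
    # device; with CPU-offloaded parameters, fall back to the per-tensor
    # path rather than erroring at runtime.
    return nn.utils.clip_grad_norm_(
        parameters,
        max_norm,
        foreach=False if offload_params else None,
    )
```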
