@@ -449,6 +449,44 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
 
             fut.result()
 
+    def _dtensor_to_full_tensor(self, dtensor: DTensor) -> torch.Tensor:
+        """Convert a DTensor to a full tensor, handling CPU offloaded tensors."""
+        local_tensor = dtensor.to_local()
+        if local_tensor.device.type != "cpu":
+            return dtensor.full_tensor()
+
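+        # Local shard is offloaded to CPU: rebuild a DTensor on the original
+        # mesh and placements so the shards can be gathered into a full tensor.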
+        device_mesh = dtensor.device_mesh
+        placements = dtensor.placements
+        temp_dtensor = DTensor.from_local(
+            local_tensor,
+            device_mesh=device_mesh,
+            placements=placements,
+        )
+        return temp_dtensor.full_tensor()
+
+    def _materialize_and_update_bucket(
+        self,
+        meta: WeightUpdateMeta,
+        named_params: list[tuple[str, nn.Parameter]],
+    ):
+        """Materialize DTensors to full tensors and broadcast to inference engine."""
+        main_rank: bool = dist.get_rank() == 0
+        named_tensors = []
+
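+        # All ranks take part in materializing full tensors (gathering shards is
+        # collective), but only rank 0 keeps them for the broadcast.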
+        for name, param in named_params:
+            if isinstance(param.data, DTensor):
+                tensor = self._dtensor_to_full_tensor(param.data)
+            else:
+                tensor = param.data
+                if tensor.device.type == "cpu":
+                    tensor = tensor.to(current_platform.device_type)
+
+            if main_rank:
+                named_tensors.append((name, tensor))
+
+        if named_tensors:
+            self._update_bucket_weights_from_distributed(meta, named_tensors)
+
     @trace_perf("fsdp_engine.update_weights_from_distributed", category="comm")
     def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
         """Broadcast parameters (chunked) from rank 0 (FSDP2 compatible)."""
@@ -459,32 +497,33 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
         dist.barrier(group=self.cpu_group)
 
         weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
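+        # World size of the dp_sp mesh group, used below to estimate the full
+        # (unsharded) size of each parameter from its local shard.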
+        fsdp_world_size = dist.get_world_size(self.world_mesh["dp_sp"].get_group())
 
         buffer_size = 0
-        named_tensors = []
+        named_params_bucket: list[tuple[str, nn.Parameter]] = []
 
         for name, param in self.get_model_name_parameters():
             if isinstance(param.data, DTensor):
-                tensor = param.data.full_tensor()
+                local_tensor = param.data.to_local()
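+                # Estimate the full tensor's footprint (local shard size times
+                # shard count) without materializing it yet.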
+                tensor_size = local_tensor.numel() * local_tensor.element_size()
+                tensor_size *= fsdp_world_size
             else:
-                tensor = param.data
+                tensor_size = param.data.numel() * param.data.element_size()
 
-            # Ranks other than 0 only help to get the full tensor
-            if dist.get_rank() != 0:
-                continue
-
-            tensor_size = tensor.numel() * tensor.element_size()
-
-            if tensor_size + buffer_size > weight_chunked_mem_size:
-                self._update_bucket_weights_from_distributed(meta, named_tensors)
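+            # Flush the current bucket before adding a parameter that would push
+            # it past the configured chunk size.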
+            if (
+                tensor_size + buffer_size > weight_chunked_mem_size
+                and named_params_bucket
+            ):
+                self._materialize_and_update_bucket(meta, named_params_bucket)
+                named_params_bucket = []
                 buffer_size = 0
 
-            named_tensors.append((name, tensor))
+            named_params_bucket.append((name, param))
             buffer_size += tensor_size
 
-        # Only rank-0 CAN contain named tensors here
-        if named_tensors:
-            self._update_bucket_weights_from_distributed(meta, named_tensors)
+        # Process remaining parameters
+        if named_params_bucket:
+            self._materialize_and_update_bucket(meta, named_params_bucket)
 
         dist.barrier(group=self.cpu_group)
 
@@ -808,6 +847,7 @@ def train_batch(
             list(self.model.parameters()),
             self.world_mesh,
             max_norm=self.optimizer_config.gradient_clipping,
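+            # Pass the FSDP CPU-offload setting through to gradient clipping.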
+            offload_params=self.config.fsdp.offload_params,
         )
 
         if not math.isfinite(grad_norm):