Commit 5998f86

Authored by Mehant Kammakomati
refactor: nit change for get_parameters_from_modules (code debt) (#3815)
* refactor: nit change for get_parameters_from_modules
  Signed-off-by: Mehant Kammakomati <[email protected]>
* fix: quality check
  Signed-off-by: Mehant Kammakomati <[email protected]>
---------
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent f0313a6 commit 5998f86

File tree

1 file changed (+3, -6 lines)


src/accelerate/utils/fsdp_utils.py

Lines changed: 3 additions & 6 deletions
@@ -630,11 +630,8 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
         "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
         "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
+        "ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device),
     }
-    if fsdp2_plugin.ignored_modules is not None:
-        fsdp2_kwargs["ignored_params"] = get_parameters_from_modules(
-            fsdp2_plugin.ignored_modules, model, accelerator.device
-        )
 
     model_has_params4bit = False
     for name, param in model.named_parameters():
@@ -808,10 +805,10 @@ def get_parameters_from_modules(
         modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
 
     Returns:
-        `List[torch.nn.Parameter]`: List of parameters
+        `set[torch.nn.Parameter]`: List of parameters
     """
     if modules is None:
-        return None
+        return set()
     parameters = []
     # code taken from accelerate while preparing kwargs for FSDP
     if isinstance(modules, str):
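For readers skimming the change: the refactor works because get_parameters_from_modules now always returns a set (empty when ignored_modules is None), so the caller can pass ignored_params unconditionally instead of guarding it with an if. Below is a minimal, self-contained sketch of that contract; collect_ignored_params is a hypothetical stand-in, not the actual accelerate helper, which additionally takes the model and a device and supports a comma-separated string of module class names.

# Minimal sketch (assumption, not the actual accelerate implementation) of the
# post-refactor contract: always return a set of parameters, never None.
from typing import Iterable, Optional

import torch


def collect_ignored_params(
    modules: Optional[Iterable[torch.nn.Module]],
) -> set[torch.nn.Parameter]:
    """Return the union of all parameters of `modules`, or an empty set if `modules` is None."""
    if modules is None:
        # Empty set means "ignore nothing", so callers need no None check.
        return set()
    params: set[torch.nn.Parameter] = set()
    for module in modules:
        params.update(module.parameters())
    return params


if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    # The caller can now build its kwargs unconditionally, mirroring the diff above:
    kwargs = {"ignored_params": collect_ignored_params([model[1]])}
    print(len(kwargs["ignored_params"]))  # 2 (weight and bias of the second Linear)
    print(collect_ignored_params(None))   # set()

Since an empty ignored_params set is equivalent to ignoring no parameters, passing it straight through is presumably what allows the None check in fsdp2_prepare_model to be dropped.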
