Enable cpu offload with weights inside the module #2214
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
BenjaminBossan left a comment
Thanks for tackling this annoying issue. Overall, I understand too little about the mechanism used by accelerate to control this, so I'd leave that part of the review to others.
> This PR adds the possibility to perform cpu offload with the weights stored inside the module.
Sorry for being dense, but where exactly is that happening?
> not sure about the naming of the arg as it can be confusing
Yes, I'd definitely rename the argument, especially since we already have a `cpu_offload` function in the same file.
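For reference, the existing function with the clashing name is part of accelerate's public API; a minimal illustration of it:

```python
from torch import nn
from accelerate import cpu_offload

model = nn.Linear(8, 8)

# existing public API: moves all of the model's weights to CPU and streams them
# onto the execution device during the forward pass
cpu_offload(model, execution_device="cuda:0")
```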
Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
module, or a dictionary mapping module name to boolean.
cpu_offload (`Union[bool, Dict[str, bool]]`, *optional*, defaults to `False`):
    Whether the weights offloaded on the cpu should be kept in the module or not.
The docstring doesn't explain what passing a dict does here.
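For context, a rough sketch of what the dict form could look like given the annotated type; the per-module semantics are an assumption (which is exactly what the docstring should spell out), and `model`, `device_map`, `tmp_dir` and the submodule names are placeholders:

```python
# hypothetical: keep the RAM-offloaded weights of linear1 inside the module,
# but leave batchnorm with the weights-map behavior
dispatch_model(
    model,
    device_map,
    offload_dir=tmp_dir,
    cpu_offload={"linear1": True, "batchnorm": False},
)
```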
dispatch_model(model, device_map, offload_dir=tmp_dir, cpu_offload=True)

self.assertEqual(model.linear1.weight.device, torch.device("meta"))
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
The new behavior of getting "cpu" here instead of "meta" looks more intuitive to me.
Yes, this is what we are aiming for in this PR! We want to leave the module on cpu and not on the meta device.
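As a rough sketch of the intended behavior, assuming a toy model with the submodule names used in the test and a made-up `device_map`:

```python
import torch
from torch import nn
from accelerate import dispatch_model

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.batchnorm = nn.BatchNorm1d(4)
        self.linear2 = nn.Linear(4, 4)

model = ToyModel()
# made-up map: keep the linears on GPU 0, offload batchnorm to CPU RAM
device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": 0}

dispatch_model(model, device_map, cpu_offload=True)  # cpu_offload is the new argument

# with cpu_offload=True the RAM-offloaded weight stays attached to the module on "cpu";
# before this PR it would have been swapped for a tensor on the "meta" device
assert model.batchnorm.weight.device == torch.device("cpu")
```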
offload_dir: Optional[Union[str, os.PathLike]] = None,
offload_index: Optional[Dict[str, str]] = None,
offload_buffers: bool = False,
cpu_offload: bool = False,
In general, should newly added parameters be placed last in case someone calls this function with purely positional arguments?
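A generic illustration of the concern (these toy functions are not the real accelerate signatures): inserting a new parameter before an existing one silently re-binds purely positional calls.

```python
def dispatch_v1(model, device_map, offload_buffers=False):
    return {"offload_buffers": offload_buffers}

# new argument inserted *before* offload_buffers instead of appended at the end
def dispatch_v2(model, device_map, cpu_offload=False, offload_buffers=False):
    return {"cpu_offload": cpu_offload, "offload_buffers": offload_buffers}

print(dispatch_v1(None, {}, True))  # {'offload_buffers': True}
print(dispatch_v2(None, {}, True))  # {'cpu_offload': True, 'offload_buffers': False}
```

Placing new arguments last, or making them keyword-only, avoids this.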
):
    self.execution_device = execution_device
    self.offload = offload
    self.cpu_offload = cpu_offload
Would it make sense to pass both offload and cpu_offload? It seems the former would take precedence over the latter. Maybe this could be checked or documented?
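One way to make the precedence explicit, as a sketch; this helper is hypothetical and not part of the diff, which only stores both flags on the hook:

```python
import warnings

def check_offload_flags(offload: bool, cpu_offload: bool) -> None:
    # hypothetical guard, assuming `offload` takes precedence over `cpu_offload`
    if offload and cpu_offload:
        warnings.warn("`offload=True` takes precedence; `cpu_offload` will have no effect.")
```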
elif self.cpu_offload:
    for name, _ in named_module_tensors(module, recurse=self.place_submodules):
        set_module_tensor_to_device(module, name, "cpu")
Is special handling for Linear8bitLt required, similar to above?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
offload_dir: Optional[Union[str, os.PathLike]] = None,
offload_index: Optional[Dict[str, str]] = None,
offload_buffers: bool = False,
cpu_offload: bool = False,
I think it can be confusing as CPU offloading is already indicated in the device_map.
IMO, ideally there should not be any argument added, and by default the weights of modules offloaded to RAM should be on the cpu device, not meta. That said, this is kind of a breaking change in case anybody is assuming that by default attached weights are on meta and weights_map holds the true weights.
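For context, CPU offload is already expressed through the `device_map` values themselves; for example (module names made up):

```python
device_map = {
    "transformer.h.0": 0,      # stays on GPU 0
    "transformer.h.1": "cpu",  # offloaded to CPU RAM
    "lm_head": "disk",         # offloaded to disk
}
```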
What does this PR do?
This PR adds the possibility to perform cpu offload with the weights stored inside the module. You just need to pass `cpu_offload=True` to `dispatch_model` (see the usage sketch below). (Not sure about the naming of the arg as it can be confusing.)

Before this PR, all offloaded modules were placed on the `meta` device and the weights were either stored in a dict (cpu offload) or a mmap (disk offload). We would then move the modules to the execution device during the `forward`, with their respective values taken from the dict/mmap.

For the user, it seems a little counter-intuitive to put weights in a `dict` in the cpu offload case. Moreover, leaving these weights on the modules should not degrade inference performance at all. Offloading has created a number of issues about parameters being on the `meta` device. While this does not completely solve the issues related to the `meta` device, it should cover most cases, since users don't use disk offload that much.

For now, the default value is `False`, but I would like to make it the default behavior and extend it to Transformers. LMK if this makes sense.

cc @mfuntowicz since you had an issue with offloaded model + quantization
cc @LysandreJik for visibility
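A minimal usage sketch of the proposed API, with a toy model and a made-up `device_map` (`cpu_offload` is the argument introduced by this PR):

```python
import tempfile
from torch import nn
from accelerate import dispatch_model

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
device_map = {"0": 0, "1": 0, "2": "cpu"}  # offload the last layer to CPU RAM
offload_dir = tempfile.mkdtemp()

# without cpu_offload (current behavior), the offloaded weights live in a separate
# weights map and the module's own parameters are replaced by "meta" tensors;
# with cpu_offload=True they stay attached to the module on the cpu device
dispatch_model(model, device_map, offload_dir=offload_dir, cpu_offload=True)
print(model[2].weight.device)  # expected: cpu (previously: meta)
```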
Partially solves #1190
TODO: