fix: fix CPU offloading in FSDP grad clipping and weight updates #680
Conversation
Summary of Changes
Hello @rchardx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces crucial enhancements to the FSDP engine, primarily focusing on robust support for CPU offloading. It ensures that distributed tensors (DTensors) offloaded to the CPU are correctly handled during weight synchronization, preventing materialization issues. It also optimizes memory usage during distributed weight updates by refactoring the parameter-gathering mechanism into a batched approach. Finally, it extends the gradient norm calculation and clipping functionality to operate on gradients that reside on the CPU, improving the overall stability and flexibility of the FSDP implementation.
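For readers who want a concrete picture of the batched gathering described above, here is a minimal sketch. It is illustrative only: the helper name `gather_params_in_batches`, the byte budget, and the assumption that `DTensor.full_tensor()` is what materializes each shard are hypothetical, not the code in this PR.

```python
# Hypothetical sketch of batched parameter gathering for weight sync.
# The helper name and byte budget are illustrative, not AReaL's API.
import torch
from torch.distributed.tensor import DTensor  # import path varies by PyTorch version


def gather_params_in_batches(named_params, budget_bytes=1 << 30):
    """Yield lists of (name, full_tensor) pairs, keeping each batch under
    `budget_bytes` so only a bounded number of full tensors exist at once."""
    batch, batch_bytes = [], 0
    for name, param in named_params:
        tensor = param.data
        if isinstance(tensor, DTensor):
            # Materialize the full tensor from its shards; the point of the PR
            # is that this must also work when the local shard is on the CPU.
            tensor = tensor.full_tensor()
        nbytes = tensor.numel() * tensor.element_size()
        if batch and batch_bytes + nbytes > budget_bytes:
            yield batch
            batch, batch_bytes = [], 0
        batch.append((name, tensor))
        batch_bytes += nbytes
    if batch:
        yield batch
```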
Code Review
This pull request introduces important fixes for FSDP CPU offloading. The changes correctly handle CPU-resident DTensors for weight synchronization and refactor weight gathering for better memory efficiency. The gradient norm calculation and clipping functions are also updated to support CPU-resident gradients. My review identified two critical bugs in areal/utils/fsdp/grad.py that will cause runtime errors due to incorrect tensor-to-scalar conversion. I have also noted a performance regression in the gradient norm calculation for non-offloaded gradients and provided suggestions for a fix. The other changes are well-implemented and align with the goals of the pull request.
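The offending lines are not quoted in this thread, but the class of bug the review describes typically looks like treating a 0-dim tensor as if it were already a Python scalar. The snippet below is a generic illustration of that pitfall and its fix via `.item()`; it is not the actual code from areal/utils/fsdp/grad.py.

```python
# Generic illustration of a tensor-to-scalar conversion bug, not the
# project's code: a norm reduction returns a 0-dim tensor, not a float.
import torch

grads = [torch.randn(10), torch.randn(5)]
total_norm = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(g, 2.0) for g in grads]), 2.0
)

# Buggy pattern: total_norm is still a Tensor here, so code that expects a
# plain float (formatting, comparisons with non-tensor values, serialization)
# can fail at runtime.
# grad_norm: float = total_norm

# Correct pattern: extract the Python scalar explicitly.
grad_norm: float = total_norm.item()
```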
fishcrap left a comment
LGTM!
Description
Updates the gradient clipping implementation to correctly handle parameters offloaded to CPU, bypassing CUDA-specific optimizations when necessary to prevent runtime errors. Refactors the FSDP engine's weight broadcasting logic to properly materialize and batch DTensors in offloaded scenarios. Additionally, introduces a new test suite to verify gradient normalization and clipping behavior across different device configurations.
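As a rough sketch of the offload-aware clipping described above (the helper name is hypothetical and expressing the fallback through `clip_grad_norm_`'s `foreach` flag is an assumption; the PR's actual implementation may differ):

```python
# Minimal sketch, not the PR's implementation: skip the CUDA multi-tensor
# (`foreach`) fast path whenever any gradient has been offloaded to CPU.
import torch


def clip_grad_norm_offload_aware(parameters, max_norm: float) -> torch.Tensor:
    params = [p for p in parameters if p.grad is not None]
    any_on_cpu = any(p.grad.device.type == "cpu" for p in params)
    # foreach=False forces the per-tensor fallback, which is safe for CPU
    # gradients; foreach=None lets PyTorch keep the fused path on GPU runs.
    return torch.nn.utils.clip_grad_norm_(
        params, max_norm, foreach=False if any_on_cpu else None
    )
```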
Related Issue
Fixes #644.
In addition, this PR resolves the root cause of #677.
Type of Change
Checklist