PyTorch and JAX Training Divergence #33687
martinoo31 asked this question in Q&A
Hi all,
thank you for reading my question. I have ported two models from PyTorch to JAX, one with real-valued parameters and one with complex-valued parameters. The training runs of the real-valued model are similar between the two frameworks, but with complex numbers I get completely different behaviour: the JAX version takes considerably longer to learn. This is despite implementing the same architecture and using the same hyperparameters and the same loss function. The training runs are executed on GSCDv2 with MFCC features.
Testing the models side by side on dummy inputs shows that, with batching, some output values differ by more than 1e-3, while the mean squared error (MSE) between the outputs is quite low (around 1e-13 to 1e-15, depending on the precision used for matmuls).
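For reference, the check is essentially the following minimal, self-contained sketch (the real comparison is in the side-by-side notebook linked below; the layer and shapes here are made up for illustration):

```python
# Minimal sketch of the side-by-side check (not my actual models): the same
# linear layer is built in PyTorch and in JAX from one shared weight matrix,
# then the outputs on a dummy batch are compared via max-abs-diff and MSE.
import numpy as np
import torch
import jax.numpy as jnp

rng = np.random.default_rng(0)
w = rng.standard_normal((13, 8)).astype(np.float32)   # shared weights
x = rng.standard_normal((32, 13)).astype(np.float32)  # dummy batch

with torch.no_grad():
    out_torch = (torch.from_numpy(x) @ torch.from_numpy(w)).numpy()
out_jax = np.asarray(jnp.asarray(x) @ jnp.asarray(w))

print("max abs diff:", np.abs(out_torch - out_jax).max())
print("MSE:         ", np.mean((out_torch - out_jax) ** 2))
```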
Below are the notebooks with the models:
- Side-by-side test
- PyTorch training (there are some synchronization problems between GPU and CPU, as it is extremely slow)
- JAX training
This discussion mentions convolution layers as the cause of divergence, but those are not present in my case, and the performance gap goes in the opposite direction.
I would really appreciate it if you could let me know whether I missed something, and/or if you have an idea of what could be causing the different training dynamics.
Thank you!
Edit:
I rewrote the code with manual tracking of the complex operations (two real matrices for each complex matrix), which was trivial in this case, and now the JAX version learns similarly to the PyTorch one.
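For anyone hitting the same issue, this is roughly what I mean by manual tracking (a minimal sketch, not the actual model; the shapes and names are made up): each complex weight matrix W = A + iB is stored as two real matrices, and the complex matmul is expanded by hand.

```python
# Sketch of a complex matmul tracked manually with two real weight matrices:
# (x_re + i*x_im) @ (A + i*B) = (x_re@A - x_im@B) + i*(x_re@B + x_im@A)
import jax
import jax.numpy as jnp

def complex_matmul(A, B, x_re, x_im):
    y_re = x_re @ A - x_im @ B
    y_im = x_re @ B + x_im @ A
    return y_re, y_im

# Quick cross-check against JAX's native complex dtype on random data.
kA, kB, kr, ki = jax.random.split(jax.random.PRNGKey(0), 4)
A = jax.random.normal(kA, (13, 8))
B = jax.random.normal(kB, (13, 8))
x_re = jax.random.normal(kr, (32, 13))
x_im = jax.random.normal(ki, (32, 13))

y_re, y_im = complex_matmul(A, B, x_re, x_im)
y_complex = (x_re + 1j * x_im) @ (A + 1j * B)
print(jnp.max(jnp.abs(y_re - y_complex.real)),
      jnp.max(jnp.abs(y_im - y_complex.imag)))
```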