Right now I have to wrap Shampoo in optax.flatten; otherwise tracing on init takes extremely long (284 s in one case). I don't know the compile time because I didn't wait for it to finish, but it's at least 10 minutes.
But wrapping in optax.flatten concatenates every parameter into a single flat array and reshards that entire array on every step, which is often going to be quite slow.