
Advice on how to reduce compile times? #24

@GallagherCommaJack


Right now I have to wrap Shampoo in `optax.flatten`, or tracing on init takes extremely long (284 s in one case). I don't know the compile time because I didn't wait for it to finish, but it's at least 10 minutes.

But wrapping in `optax.flatten` means the entire parameter array is resharded on every step, which will often be quite slow.
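As far as I can tell from the optax source, `optax.flatten` is built on `jax.flatten_util.ravel_pytree`, which concatenates every leaf into a single 1-D vector and returns a function that splits it back. A toy illustration (the two-leaf tree is made up):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical two-leaf parameter tree.
params = {
    "w": jnp.ones((4, 3)),
    "b": jnp.zeros((3,)),
}

# Concatenate all leaves into one 1-D vector (4*3 + 3 = 15 elements).
flat, unravel = ravel_pytree(params)
print(flat.shape)

# Split the vector back into the original per-leaf shapes.
restored = unravel(flat)
```

With sharded parameters, presumably it is building that single vector and splitting the updates back into per-leaf shapes that forces the per-step resharding.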
