Skip to content

Commit 5853f99

Browse files
committed
AdaMuon impl w/ a few other ideas based on recent reading
1 parent 4080bf8 commit 5853f99

File tree

3 files changed

+360
-26
lines changed

3 files changed

+360
-26
lines changed

tests/test_optim.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,14 @@ def test_muon(optimizer):
402402
_test_model(optimizer, dict(lr=1e-3))
403403

404404

405+
@pytest.mark.parametrize('optimizer', ['adamuon', 'nadamuon'])
406+
def test_adamuon(optimizer):
407+
_test_rosenbrock(
408+
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
409+
)
410+
_test_model(optimizer, dict(lr=1e-3))
411+
412+
405413
@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
406414
def test_adopt(optimizer):
407415
_test_rosenbrock(

timm/optim/_optim_factory.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,24 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
897897
has_betas=True,
898898
defaults={'nesterov': True}
899899
),
900+
OptimInfo(
901+
name='adamuon',
902+
opt_class=Muon,
903+
description='AdaMuon: Muon with adaptive second moment estimation on orthogonalized directions',
904+
has_momentum=True,
905+
has_eps=True,
906+
has_betas=True,
907+
defaults={'algo': 'adamuon'}
908+
),
909+
OptimInfo(
910+
name='nadamuon',
911+
opt_class=Muon,
912+
description='AdaMuon with Nesterov momentum and NAdamW fallback for 1D params',
913+
has_momentum=True,
914+
has_eps=True,
915+
has_betas=True,
916+
defaults={'algo': 'adamuon', 'nesterov': True}
917+
),
900918
OptimInfo(
901919
name='novograd',
902920
opt_class=NvNovoGrad,

0 commit comments

Comments
 (0)