
Merge pull request #312 from kozistr/feature/apollo-optimizer
[Feature] Implement APOLLO optimizer
kozistr authored Dec 15, 2024
2 parents 7bb85f9 + b40e8d4 commit 38bafd8
Showing 14 changed files with 396 additions and 174 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **84 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -192,6 +192,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
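
As a quick sketch outside the diff itself: the new entry can be checked from the public API, assuming `get_supported_optimizers` accepts wildcard filters exactly as in the `get_supported_optimizers(['adam*', 'ranger*'])` call shown in this hunk's context.

```python
from pytorch_optimizer import get_supported_optimizers

# Wildcard filter, mirroring the `['adam*', 'ranger*']` example above.
# The result is expected to contain the new APOLLO entry alongside the
# renamed ApolloDQN; the exact return format (names vs. classes) may vary
# by version, so treat this purely as a sketch.
print(get_supported_optimizers(['apollo*']))
```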

## Supported LR Scheduler

3 changes: 3 additions & 0 deletions docs/changelogs/v3.3.1.md
@@ -4,6 +4,9 @@

* Support `Cautious` variant to `AdaShift` optimizer. (#310)
* Save the state of the `Lookahead` optimizer too. (#310)
* Implement `APOLLO` optimizer. (#311, #312)
    * [SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)
* Rename the `Apollo` (`An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization`) optimizer to `ApolloDQN` so that it does not clash with the new `APOLLO` optimizer. (#312)
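
A hedged migration sketch for the rename above, assuming both classes are exported from the package root as the `pytorch_optimizer/__init__.py` diff later in this commit shows; the tiny linear model is only a placeholder.

```python
from torch import nn

from pytorch_optimizer import APOLLO, ApolloDQN

model = nn.Linear(4, 2)  # placeholder model, only here to provide parameters

# The former `Apollo` (diagonal quasi-Newton) optimizer now lives under `ApolloDQN`.
dqn = ApolloDQN(model.parameters(), lr=1e-2)

# The new APOLLO optimizer introduced in this PR; defaults follow the class
# signature in `pytorch_optimizer/optimizer/apollo.py` further down.
apollo = APOLLO(model.parameters(), lr=1e-2, betas=(0.9, 0.999), eps=1e-6)
```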

### Bug

3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **84 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -192,6 +192,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
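
A sketch of APOLLO's rank-based path, complementing the basic usage shown earlier: the `rank`, `update_proj_gap`, `scale`, and `projection_type` keyword arguments are inferred from the `step()` implementation in `apollo.py` further down, which forwards extra keyword arguments into the parameter group; the concrete values, including `projection_type='std'`, are illustrative assumptions rather than recommended settings.

```python
from torch import nn

from pytorch_optimizer import APOLLO

model = nn.Linear(128, 64)  # 2-D weight, so the low-rank branch in `step()` applies

# Extra keyword arguments land in the parameter-group defaults and are picked up
# by the `if 'rank' in group and p.dim() > 1` branch, which builds a GaLoreProjector.
optimizer = APOLLO(
    model.parameters(),
    lr=1e-2,
    rank=1,
    update_proj_gap=200,
    scale=32.0,
    projection_type='std',
)
```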

## Supported LR Scheduler

6 changes: 5 additions & 1 deletion docs/optimizer.md
@@ -116,7 +116,11 @@
:docstring:
:members:

::: pytorch_optimizer.Apollo
::: pytorch_optimizer.APOLLO
:docstring:
:members:

::: pytorch_optimizer.ApolloDQN
:docstring:
:members:

16 changes: 8 additions & 8 deletions pyproject.toml
@@ -13,14 +13,14 @@ keywords = [
"pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
"AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
"AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
"Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion",
"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp",
"LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID",
"PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
"ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP",
"SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
"bitsandbytes", "WSD", "QGaLore",
"Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD",
"DAdaptLion", "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate",
"Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam",
"PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
"ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW",
"SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE",
"BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky",
"LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
3 changes: 2 additions & 1 deletion pytorch_optimizer/__init__.py
@@ -42,6 +42,7 @@
)
from pytorch_optimizer.optimizer import (
ADOPT,
APOLLO,
ASGD,
BSAM,
CAME,
@@ -90,7 +91,7 @@
Aida,
AliG,
Amos,
Apollo,
ApolloDQN,
AvaGrad,
DAdaptAdaGrad,
DAdaptAdam,
5 changes: 3 additions & 2 deletions pytorch_optimizer/optimizer/__init__.py
@@ -34,7 +34,7 @@
from pytorch_optimizer.optimizer.aida import Aida
from pytorch_optimizer.optimizer.alig import AliG
from pytorch_optimizer.optimizer.amos import Amos
from pytorch_optimizer.optimizer.apollo import Apollo
from pytorch_optimizer.optimizer.apollo import APOLLO, ApolloDQN
from pytorch_optimizer.optimizer.avagrad import AvaGrad
from pytorch_optimizer.optimizer.came import CAME
from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptLion, DAdaptSGD
@@ -228,7 +228,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
DAdaptAdan,
AdamS,
AdaFactor,
Apollo,
ApolloDQN,
APOLLO,
SWATS,
NovoGrad,
Lion,
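
With both names registered here, string-based loading is a sketch away; this assumes `load_optimizer` is re-exported at the package root and that the registry keys are the lowercased class names, neither of which this hunk shows.

```python
from torch import nn

from pytorch_optimizer import load_optimizer  # assumed top-level re-export

model = nn.Linear(8, 2)  # placeholder model

apollo_cls = load_optimizer('apollo')  # assumed key: lowercased class name
optimizer = apollo_cls(model.parameters(), lr=1e-2)
```
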
168 changes: 162 additions & 6 deletions pytorch_optimizer/optimizer/apollo.py
@@ -1,14 +1,18 @@
from typing import Optional
import math
from typing import Literal, Optional

import numpy as np
import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector

SCALE_TYPE = Literal['channel', 'tensor']

class Apollo(BaseOptimizer):

class ApolloDQN(BaseOptimizer):
r"""An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -25,8 +29,8 @@ class Apollo(BaseOptimizer):
def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
init_lr: Optional[float] = None,
lr: float = 1e-2,
init_lr: Optional[float] = 1e-5,
beta: float = 0.9,
rebound: str = 'constant',
weight_decay: float = 0.0,
@@ -58,7 +62,7 @@ def __init__(
super().__init__(params, defaults)

def __str__(self) -> str:
return 'Apollo'
return 'ApolloDQN'

@torch.no_grad()
def reset(self):
@@ -146,3 +150,155 @@ def step(self, closure: CLOSURE = None) -> LOSS:
p.add_(d_p, alpha=-current_lr)

return loss


class APOLLO(BaseOptimizer):
r"""SGD-like Memory, AdamW-level Performance.
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
:param scale_type: SCALE_TYPE. the level at which the gradient scaling factor is computed, either 'channel' or 'tensor'.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param fixed_decay: bool. fix weight decay.
:param correct_bias: bool. whether to correct bias in Adam.
:param eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-2,
betas: BETAS = (0.9, 0.999),
scale_type: SCALE_TYPE = 'tensor',
weight_decay: float = 0.0,
weight_decouple: bool = True,
fixed_decay: bool = False,
correct_bias: bool = True,
eps: float = 1e-6,
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'scale_type': scale_type,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'fixed_decay': fixed_decay,
'correct_bias': correct_bias,
'eps': eps,
**kwargs,
}
super().__init__(params, defaults)

def __str__(self) -> str:
return 'APOLLO'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

beta1, beta2 = group['betas']

step_size: float = group['lr']
if group['correct_bias']:
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
step_size *= bias_correction2_sq / bias_correction1

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(str(self))

state = self.state[p]
if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

if 'rank' in group and p.dim() > 1:
if 'projector' not in state:
state['projector'] = GaLoreProjector(
rank=group['rank'],
update_proj_gap=group['update_proj_gap'],
scale=group['scale'],
projection_type=group['projection_type'],
)

grad = state['projector'].project(grad, group['step'], from_random_matrix=True)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

de_nom = exp_avg_sq.sqrt().add_(group['eps'])

norm_grad = exp_avg / de_nom
if 'rank' in group and p.dim() > 1:
if group['scale_type'] == 'channel':
norm_dim: int = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
scaling_factor = torch.norm(norm_grad, dim=norm_dim) / (torch.norm(grad, dim=norm_dim) + 1e-8)
if norm_dim == 1:
scaling_factor = scaling_factor.unsqueeze(1)
else:
scaling_factor = torch.norm(norm_grad) / (torch.norm(grad) + 1e-8)

scaling_grad = grad * scaling_factor

scaling_grad_norm = torch.norm(scaling_grad)
if 'scaling_grad' in state:
limiter = (
max(
scaling_grad_norm / (state['scaling_grad'] + 1e-8),
1.01,
)
/ 1.01
)

scaling_grad.div_(limiter)
scaling_grad_norm.div_(limiter)

state['scaling_grad'] = scaling_grad_norm

norm_grad = scaling_grad * np.sqrt(group['scale'])
norm_grad = state['projector'].project_back(norm_grad)

p.add_(norm_grad, alpha=-step_size)

self.apply_weight_decay(
p,
grad,
lr=step_size,
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=group['fixed_decay'],
)

return loss
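
To round off the new class, a minimal end-to-end sketch on APOLLO's dense path (no `rank` passed, so the projection and gradient-scaling branches above are skipped); the model, data, and hyperparameters are placeholders.

```python
import torch
from torch import nn

from pytorch_optimizer import APOLLO

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()
optimizer = APOLLO(model.parameters(), lr=1e-2, weight_decay=1e-2)

x, y = torch.randn(64, 16), torch.randn(64, 1)  # toy regression data

for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```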
