Merge pull request #318 from kozistr/feature/grams-optimizer
[Feature] Implement `Grams` optimizer
kozistr authored Jan 5, 2025
2 parents 8f538d4 + 66f0675 commit ddc81a7
Showing 16 changed files with 169 additions and 36 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

## The reasons to use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **87 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -195,6 +195,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
| MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |
| Grams | *Grams: Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |

## Supported LR Scheduler

6 changes: 6 additions & 0 deletions docs/changelogs/v3.3.3.md
@@ -0,0 +1,6 @@
### Change Log

### Feature

* Implement `Grams` optimizer. (#317, #318)
* [Grams: Gradient Descent with Adaptive Momentum Scaling](https://arxiv.org/abs/2412.17107)
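A minimal usage sketch of the new optimizer, assuming this change is installed: `Grams` is exported from the package top level by this PR, and the arguments mirror the defaults in `pytorch_optimizer/optimizer/grams.py`. The model, data, and training loop below are purely illustrative.

```python
import torch
from torch import nn

from pytorch_optimizer import Grams  # exported via pytorch_optimizer/__init__.py in this PR

# toy model and data, for illustration only
model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

optimizer = Grams(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2, weight_decouple=True)

for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```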
3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

## The reasons to use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **87 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -195,6 +195,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
| MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |
| Grams | *Grams: Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |

## Supported LR Scheduler

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -188,6 +188,10 @@
:docstring:
:members:

::: pytorch_optimizer.Grams
:docstring:
:members:

::: pytorch_optimizer.Gravity
:docstring:
:members:
8 changes: 8 additions & 0 deletions docs/visualization.md
@@ -162,6 +162,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_GaLore.png)

### Grams

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Grams.png)

### Gravity

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Gravity.png)
@@ -484,6 +488,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_GaLore.png)

### Grams

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Grams.png)

### Gravity

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Gravity.png)
Binary file added docs/visualizations/rastrigin_Grams.png
Binary file added docs/visualizations/rosenbrock_Grams.png
50 changes: 25 additions & 25 deletions poetry.lock


8 changes: 4 additions & 4 deletions pyproject.toml
@@ -14,10 +14,10 @@ keywords = [
"AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
"AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
"Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD",
"DAdaptLion", "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate",
"Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero", "NovoGrad",
"PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
"ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW",
"DAdaptLion", "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Grams", "Gravity", "GrokFast", "GSAM",
"Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
"NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad",
"SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW",
"SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE",
"BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky",
"LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
1 change: 1 addition & 0 deletions pytorch_optimizer/__init__.py
@@ -105,6 +105,7 @@
FAdam,
Fromage,
GaLore,
Grams,
Gravity,
GrokFastAdamW,
Kate,
2 changes: 2 additions & 0 deletions pytorch_optimizer/optimizer/__init__.py
@@ -46,6 +46,7 @@
from pytorch_optimizer.optimizer.ftrl import FTRL
from pytorch_optimizer.optimizer.galore import GaLore
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.grams import Grams
from pytorch_optimizer.optimizer.gravity import Gravity
from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW
from pytorch_optimizer.optimizer.kate import Kate
@@ -282,6 +283,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
LaProp,
MARS,
SGDSaI,
Grams,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

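Because `Grams` is appended to `OPTIMIZER_LIST`, it is also reachable by name through `load_optimizer`, whose registry keys are the lower-cased class names. A small sketch, assuming `load_optimizer` is re-exported at the package top level as in earlier releases:

```python
import torch

from pytorch_optimizer import load_optimizer

params = [torch.nn.Parameter(torch.randn(2, 2))]

# 'grams' is the lower-cased class name used as the registry key
optimizer = load_optimizer('grams')(params, lr=1e-3)
```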
108 changes: 108 additions & 0 deletions pytorch_optimizer/optimizer/grams.py
@@ -0,0 +1,108 @@
import math

import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Grams(BaseOptimizer):
r"""Gradient Descent with Adaptive Momentum Scaling.
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. decoupled weight decay.
:param eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.9, 0.999),
weight_decay: float = 0.0,
weight_decouple: bool = True,
eps: float = 1e-6,
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'eps': eps,
}

super().__init__(params, defaults)

def __str__(self) -> str:
return 'Grams'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

beta1, beta2 = group['betas']

bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
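# NOTE: debias() is assumed to return the usual Adam bias-correction factor, 1 - beta ** step.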

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(str(self))

state = self.state[p]
if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
exp_avg.lerp_(grad, weight=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

update = (exp_avg / bias_correction1) / (exp_avg_sq.sqrt() / bias_correction2_sq).add_(group['eps'])
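# Grams: keep the magnitude of the bias-corrected Adam-style step, but take the sign from the raw gradient.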
update.abs_().mul_(grad.sign())

self.apply_weight_decay(
p,
grad,
lr=group['lr'],
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=False,
)

p.add_(update, alpha=-group['lr'])

return loss
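For reference, a sketch of the update rule that `step` implements (weight decay omitted; $g_t$ is the gradient, and $\beta_1$, $\beta_2$, $\epsilon$, $\eta$ follow the constructor arguments):

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t,\\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^{2},\\
\theta_t &= \theta_{t-1} - \eta\, \operatorname{sign}(g_t) \odot \left|\frac{m_t / (1 - \beta_1^{t})}{\sqrt{v_t / (1 - \beta_2^{t})} + \epsilon}\right|.
\end{aligned}
$$

The only departure from Adam is the last line: the step keeps Adam's magnitude while its sign is taken from the current gradient, which is what `update.abs_().mul_(grad.sign())` does.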
