-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #318 from kozistr/feature/grams-optimizer
[Feature] Implement `Grams` optimizer
- Loading branch information
Showing
16 changed files
with
169 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
### Change Log | ||
|
||
### Feature | ||
|
||
* Implement `Grams` optimizer. (#317, #318) | ||
* [Grams: Gradient Descent with Adaptive Momentum Scaling](https://arxiv.org/abs/2412.17107) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,7 @@ | |
FAdam, | ||
Fromage, | ||
GaLore, | ||
Grams, | ||
Gravity, | ||
GrokFastAdamW, | ||
Kate, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import math | ||
|
||
import torch | ||
|
||
from pytorch_optimizer.base.exception import NoSparseGradientError | ||
from pytorch_optimizer.base.optimizer import BaseOptimizer | ||
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS | ||
|
||
|
||
class Grams(BaseOptimizer): | ||
r"""Gradient Descent with Adaptive Momentum Scaling. | ||
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. | ||
:param lr: float. learning rate. | ||
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. | ||
:param weight_decay: float. weight decay (L2 penalty). | ||
:param weight_decouple: bool. decoupled weight decay. | ||
:param eps: float. term added to the denominator to improve numerical stability. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
params: PARAMETERS, | ||
lr: float = 1e-3, | ||
betas: BETAS = (0.9, 0.999), | ||
weight_decay: float = 0.0, | ||
weight_decouple: bool = True, | ||
eps: float = 1e-6, | ||
**kwargs, | ||
): | ||
self.validate_learning_rate(lr) | ||
self.validate_betas(betas) | ||
self.validate_non_negative(weight_decay, 'weight_decay') | ||
self.validate_non_negative(eps, 'eps') | ||
|
||
defaults: DEFAULTS = { | ||
'lr': lr, | ||
'betas': betas, | ||
'weight_decay': weight_decay, | ||
'weight_decouple': weight_decouple, | ||
'eps': eps, | ||
} | ||
|
||
super().__init__(params, defaults) | ||
|
||
def __str__(self) -> str: | ||
return 'Grams' | ||
|
||
@torch.no_grad() | ||
def reset(self): | ||
for group in self.param_groups: | ||
group['step'] = 0 | ||
for p in group['params']: | ||
state = self.state[p] | ||
|
||
state['exp_avg'] = torch.zeros_like(p) | ||
state['exp_avg_sq'] = torch.zeros_like(p) | ||
|
||
@torch.no_grad() | ||
def step(self, closure: CLOSURE = None) -> LOSS: | ||
loss: LOSS = None | ||
if closure is not None: | ||
with torch.enable_grad(): | ||
loss = closure() | ||
|
||
for group in self.param_groups: | ||
if 'step' in group: | ||
group['step'] += 1 | ||
else: | ||
group['step'] = 1 | ||
|
||
beta1, beta2 = group['betas'] | ||
|
||
bias_correction1: float = self.debias(beta1, group['step']) | ||
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) | ||
|
||
for p in group['params']: | ||
if p.grad is None: | ||
continue | ||
|
||
grad = p.grad | ||
if grad.is_sparse: | ||
raise NoSparseGradientError(str(self)) | ||
|
||
state = self.state[p] | ||
if len(state) == 0: | ||
state['exp_avg'] = torch.zeros_like(p) | ||
state['exp_avg_sq'] = torch.zeros_like(p) | ||
|
||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] | ||
exp_avg.lerp_(grad, weight=beta1) | ||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) | ||
|
||
update = (exp_avg / bias_correction1) / (exp_avg_sq / bias_correction2_sq).sqrt_().add_(group['eps']) | ||
update.abs_().mul_(grad.sign()) | ||
|
||
self.apply_weight_decay( | ||
p, | ||
grad, | ||
lr=group['lr'], | ||
weight_decay=group['weight_decay'], | ||
weight_decouple=group['weight_decouple'], | ||
fixed_decay=False, | ||
) | ||
|
||
p.add_(update, alpha=-group['lr']) | ||
|
||
return loss |
Oops, something went wrong.