Merge pull request #243 from kozistr/update/adafactor-optimizer
[Feature] Tweak AdaFactor optimizer
kozistr authored Jun 8, 2024
2 parents 4a095ae + 5d924c5 commit 15d52f6
Showing 5 changed files with 41 additions and 21 deletions.
4 changes: 4 additions & 0 deletions docs/changelogs/v3.0.1.md
@@ -4,6 +4,10 @@

* Implement `FAdam` optimizer. (#241, #242)
* [Adam is a natural gradient optimizer using diagonal empirical Fisher information](https://arxiv.org/abs/2405.12807)
* Tweak `AdaFactor` optimizer. (#236, #243)
* support skipping the first momentum when `beta1` is not given
* default the dtype of the first momentum to `bfloat16`
* clamp the second momentum to at most 0.999

### Bug

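The three `AdaFactor` tweaks listed in the changelog above map directly onto constructor arguments. A minimal usage sketch, assuming `AdaFactor` is importable from the package top level (as the test suite does) and using purely illustrative hyperparameter values:

```python
# Hedged usage sketch of the tweaked optimizer; hyperparameters are illustrative only.
import torch

from pytorch_optimizer import AdaFactor

model = torch.nn.Linear(10, 2)

optimizer = AdaFactor(
    model.parameters(),
    lr=1e-3,
    betas=(None, 0.999),            # beta1=None: skip the first momentum entirely
    weight_decay=1e-3,
    momentum_dtype=torch.bfloat16,  # dtype of the first-momentum buffer when one is kept
)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

When a first momentum is kept (a float `beta1`), its buffer is allocated in `momentum_dtype`, which now defaults to `bfloat16`; with `betas=(None, 0.999)` no first-momentum buffer is created at all.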
22 changes: 12 additions & 10 deletions pytorch_optimizer/base/optimizer.py
@@ -215,7 +215,7 @@ def get_adanorm_gradient(
return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad

@staticmethod
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)'):
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
if range_type == '[)' and not low <= x < high:
raise ValueError(f'[-] {name} must be in the range [{low}, {high})')
if range_type == '[]' and not low <= x <= high:
@@ -226,40 +226,42 @@ def validate_range(x: float, name: str, low: float, high: float, range_type: str
raise ValueError(f'[-] {name} must be in the range ({low}, {high})')

@staticmethod
def validate_non_negative(x: Optional[float], name: str):
def validate_non_negative(x: Optional[float], name: str) -> None:
if x is not None and x < 0.0:
raise ValueError(f'[-] {name} must be non-negative')

@staticmethod
def validate_positive(x: Union[float, int], name: str):
def validate_positive(x: Union[float, int], name: str) -> None:
if x <= 0:
raise ValueError(f'[-] {name} must be positive')

@staticmethod
def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper'):
def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper') -> None:
if bound_type == 'upper' and constant > boundary:
raise ValueError(f'[-] constant {constant} must be in a range of (-inf, {boundary}]')
if bound_type == 'lower' and constant < boundary:
raise ValueError(f'[-] constant {constant} must be in a range of [{boundary}, inf)')

@staticmethod
def validate_step(step: int, step_type: str):
def validate_step(step: int, step_type: str) -> None:
if step < 1:
raise NegativeStepError(step, step_type=step_type)

@staticmethod
def validate_options(x: str, name: str, options: List[str]):
def validate_options(x: str, name: str, options: List[str]) -> None:
if x not in options:
opts: str = ' or '.join([f'\'{option}\'' for option in options]).strip()
raise ValueError(f'[-] {name} {x} must be one of ({opts})')

@staticmethod
def validate_learning_rate(learning_rate: Optional[float]):
def validate_learning_rate(learning_rate: Optional[float]) -> None:
if learning_rate is not None and learning_rate < 0.0:
raise NegativeLRError(learning_rate)

def validate_betas(self, betas: BETAS):
self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
def validate_betas(self, betas: BETAS) -> None:
if betas[0] is not None:
self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')

self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type='[]')

if len(betas) < 3:
@@ -268,7 +270,7 @@ def validate_betas(self, betas: BETAS):
if betas[2] is not None:
self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type='[]')

def validate_nus(self, nus: Union[float, Tuple[float, float]]):
def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
if isinstance(nus, float):
self.validate_range(nus, 'nu', 0.0, 1.0, range_type='[]')
else:
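The behavioral change in this file is that `validate_betas` now tolerates `beta1 = None` and only range-checks the remaining betas. A standalone sketch of that pattern for the two-beta case (an illustrative reconstruction, not the library's code):

```python
# beta1 may now be None, in which case its range check is skipped and only beta2 is validated.
from typing import Optional, Tuple


def validate_betas(betas: Tuple[Optional[float], float]) -> None:
    beta1, beta2 = betas
    if beta1 is not None and not 0.0 <= beta1 <= 1.0:
        raise ValueError('[-] beta1 must be in the range [0.0, 1.0]')
    if not 0.0 <= beta2 <= 1.0:
        raise ValueError('[-] beta2 must be in the range [0.0, 1.0]')


validate_betas((0.9, 0.999))   # accepted: conventional Adam-style betas
validate_betas((None, 0.999))  # accepted: first momentum disabled
```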
2 changes: 1 addition & 1 deletion pytorch_optimizer/base/types.py
@@ -6,7 +6,7 @@

CLOSURE = Optional[Callable[[], float]]
LOSS = Optional[float]
BETAS = Union[Tuple[float, float], Tuple[float, float, float]]
BETAS = Union[Tuple[float, float], Tuple[float, float, float], Tuple[None, float]]
DEFAULTS = Dict
PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
STATE = Dict
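The widened alias admits three shapes. A short illustration (only the alias line itself comes from the diff; the example assignments are hypothetical):

```python
from typing import Tuple, Union

BETAS = Union[Tuple[float, float], Tuple[float, float, float], Tuple[None, float]]

adam_style: BETAS = (0.9, 0.999)         # two betas, as before
with_beta3: BETAS = (0.9, 0.999, 0.99)   # optimizers that accept a third beta
no_momentum: BETAS = (None, 0.999)       # AdaFactor with the first momentum disabled
```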
29 changes: 21 additions & 8 deletions pytorch_optimizer/optimizer/adafactor.py
@@ -10,11 +10,12 @@


class AdaFactor(Optimizer, BaseOptimizer):
r"""Adaptive Learning Rates with Sublinear Memory Cost.
r"""Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared
hessian trace. if beta1 is None, the first momentum is skipped.
:param decay_rate: float. coefficient used to compute running averages of square gradient.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
@@ -27,6 +28,9 @@ class AdaFactor(Optimizer, BaseOptimizer):
is being used.
:param eps1: float. term added to the denominator to improve numerical stability.
:param eps2: float. term added to the denominator to improve numerical stability.
:param momentum_dtype: torch.dtype. dtype of the momentum variable. The ViT paper observed that storing momentum
in half precision (bfloat16) does not affect training dynamics or the final outcome, while reducing the
optimizer state overhead from 2-fold to 1.5-fold.
"""

def __init__(
@@ -45,6 +49,7 @@ def __init__(
warmup_init: bool = False,
eps1: float = 1e-30,
eps2: float = 1e-3,
momentum_dtype: torch.dtype = torch.bfloat16,
):
self.validate_learning_rate(lr)
self.validate_betas(betas)
@@ -56,6 +61,7 @@ def __init__(
self.clip_threshold = clip_threshold
self.eps1 = eps1
self.eps2 = eps2
self.momentum_dtype = momentum_dtype

defaults: DEFAULTS = {
'lr': lr,
@@ -87,7 +93,8 @@ def reset(self):
grad_shape: Tuple[int, ...] = grad.shape
factored: bool = self.get_options(grad_shape)

state['exp_avg'] = torch.zeros_like(p)
if group['betas'][0] is not None:
state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)

if factored:
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -149,7 +156,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
else:
group['step'] = 1

beta1, _ = group['betas']
beta1, beta2 = group['betas']

beta2_t: float = 1.0 - math.pow(group['step'], self.decay_rate)

@@ -167,7 +174,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
factored: bool = self.get_options(grad_shape)

if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)
if beta1 is not None:
state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)

if factored:
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -205,6 +213,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
else:
exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
exp_avg_sq.clamp_(max=beta2)

torch.rsqrt(exp_avg_sq, out=update)

if group['ams_bound']:
@@ -216,8 +226,8 @@

update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0)).mul_(lr)

exp_avg = state['exp_avg']
exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
if beta1 is not None:
exp_avg = state['exp_avg']
exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)

update = exp_avg

self.apply_weight_decay(
p=p,
@@ -228,6 +241,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
fixed_decay=group['fixed_decay'],
)

p.add_(-exp_avg)
p.add_(-update)

return loss
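Two pieces of the `step()` logic above are easy to miss when read as a diff: the factored second moment (for a matrix-shaped parameter only a row accumulator and a column accumulator are stored, and their combination approximates the inverse square root of the full second moment) and the RMS-based clipping applied to the update before the learning rate is folded in. A simplified sketch for a 2-D gradient; the free functions and names are illustrative, not the library's implementation:

```python
import torch


def factored_rsqrt_update(grad: torch.Tensor, row: torch.Tensor, col: torch.Tensor,
                          beta2_t: float, eps1: float = 1e-30) -> torch.Tensor:
    """Scale grad by an approximate inverse-sqrt second moment built from row/column statistics."""
    sq = grad.pow(2).add_(eps1)                                    # grad^2 + eps1, as in the diff
    row.mul_(beta2_t).add_(sq.mean(dim=-1), alpha=1.0 - beta2_t)   # per-row running average
    col.mul_(beta2_t).add_(sq.mean(dim=-2), alpha=1.0 - beta2_t)   # per-column running average
    r = (row / row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
    c = col.unsqueeze(-2).rsqrt()
    return grad * r * c                                            # factored approximation of 1/sqrt(E[g^2])


def clip_by_rms(update: torch.Tensor, clip_threshold: float = 1.0) -> torch.Tensor:
    """Scale the update down when its RMS exceeds clip_threshold; otherwise leave it unchanged."""
    rms = update.norm(2) / (update.numel() ** 0.5)
    return update / (rms / clip_threshold).clamp_(min=1.0)


# Tiny usage example with a 4x3 "parameter".
grad = torch.randn(4, 3)
row_state, col_state = torch.zeros(4), torch.zeros(3)
update = clip_by_rms(factored_rsqrt_update(grad, row_state, col_state, beta2_t=0.5))
```

In the optimizer itself, `beta2_t` follows the step-dependent schedule computed at the top of `step()`, and the non-factored branch additionally clamps `exp_avg_sq` to at most `beta2`, which is the "clamp second momentum" tweak from the changelog.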
5 changes: 3 additions & 2 deletions tests/constants.py
@@ -348,8 +348,9 @@
(DAdaptLion, {'lr': 3e0, 'weight_decay': 1e-3}, 20),
(AdamS, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
(AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
(AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
(AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 125),
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120),
(AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40),
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50),
