diff --git a/docs/changelogs/v3.0.1.md b/docs/changelogs/v3.0.1.md
index 41bd9e094..032f193ab 100644
--- a/docs/changelogs/v3.0.1.md
+++ b/docs/changelogs/v3.0.1.md
@@ -4,6 +4,10 @@
 
 * Implement `FAdam` optimizer. (#241, #242)
     * [Adam is a natural gradient optimizer using diagonal empirical Fisher information](https://arxiv.org/abs/2405.12807)
+* Tweak `AdaFactor` optimizer. (#236, #243)
+    * support skipping the first momentum when beta1 is not given
+    * default the first momentum dtype to `bfloat16`
+    * clip the second momentum to 0.999
 
 ### Bug
 
diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py
index 4e0cdb64b..722081a59 100644
--- a/pytorch_optimizer/base/optimizer.py
+++ b/pytorch_optimizer/base/optimizer.py
@@ -215,7 +215,7 @@ def get_adanorm_gradient(
         return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
 
     @staticmethod
-    def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)'):
+    def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
         if range_type == '[)' and not low <= x < high:
             raise ValueError(f'[-] {name} must be in the range [{low}, {high})')
         if range_type == '[]' and not low <= x <= high:
@@ -226,40 +226,42 @@ def validate_range(x: float, name: str, low: float, high: float, range_type: str
             raise ValueError(f'[-] {name} must be in the range ({low}, {high})')
 
     @staticmethod
-    def validate_non_negative(x: Optional[float], name: str):
+    def validate_non_negative(x: Optional[float], name: str) -> None:
         if x is not None and x < 0.0:
             raise ValueError(f'[-] {name} must be non-negative')
 
     @staticmethod
-    def validate_positive(x: Union[float, int], name: str):
+    def validate_positive(x: Union[float, int], name: str) -> None:
         if x <= 0:
             raise ValueError(f'[-] {name} must be positive')
 
     @staticmethod
-    def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper'):
+    def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper') -> None:
         if bound_type == 'upper' and constant > boundary:
             raise ValueError(f'[-] constant {constant} must be in a range of (-inf, {boundary}]')
         if bound_type == 'lower' and constant < boundary:
             raise ValueError(f'[-] constant {constant} must be in a range of [{boundary}, inf)')
 
     @staticmethod
-    def validate_step(step: int, step_type: str):
+    def validate_step(step: int, step_type: str) -> None:
         if step < 1:
             raise NegativeStepError(step, step_type=step_type)
 
     @staticmethod
-    def validate_options(x: str, name: str, options: List[str]):
+    def validate_options(x: str, name: str, options: List[str]) -> None:
         if x not in options:
             opts: str = ' or '.join([f'\'{option}\'' for option in options]).strip()
             raise ValueError(f'[-] {name} {x} must be one of ({opts})')
 
     @staticmethod
-    def validate_learning_rate(learning_rate: Optional[float]):
+    def validate_learning_rate(learning_rate: Optional[float]) -> None:
         if learning_rate is not None and learning_rate < 0.0:
             raise NegativeLRError(learning_rate)
 
-    def validate_betas(self, betas: BETAS):
-        self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
+    def validate_betas(self, betas: BETAS) -> None:
+        if betas[0] is not None:
+            self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
+
         self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type='[]')
 
         if len(betas) < 3:
@@ -268,7 +270,7 @@ def validate_betas(self, betas: BETAS):
         if betas[2] is not None:
             self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type='[]')
 
-    def validate_nus(self, nus: Union[float, Tuple[float, float]]):
+    def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
         if isinstance(nus, float):
             self.validate_range(nus, 'nu', 0.0, 1.0, range_type='[]')
         else:
diff --git a/pytorch_optimizer/base/types.py b/pytorch_optimizer/base/types.py
index e6308a45f..5d4556996 100644
--- a/pytorch_optimizer/base/types.py
+++ b/pytorch_optimizer/base/types.py
@@ -6,7 +6,7 @@
 
 CLOSURE = Optional[Callable[[], float]]
 LOSS = Optional[float]
-BETAS = Union[Tuple[float, float], Tuple[float, float, float]]
+BETAS = Union[Tuple[float, float], Tuple[float, float, float], Tuple[None, float]]
 DEFAULTS = Dict
 PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
 STATE = Dict
diff --git a/pytorch_optimizer/optimizer/adafactor.py b/pytorch_optimizer/optimizer/adafactor.py
index 2945fd1d4..cc7b0fa35 100644
--- a/pytorch_optimizer/optimizer/adafactor.py
+++ b/pytorch_optimizer/optimizer/adafactor.py
@@ -10,11 +10,12 @@
 
 
 class AdaFactor(Optimizer, BaseOptimizer):
-    r"""Adaptive Learning Rates with Sublinear Memory Cost.
+    r"""Adaptive Learning Rates with Sublinear Memory Cost, with some tweaks.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared
+        hessian trace. if beta1 is None, the first momentum is skipped.
     :param decay_rate: float. coefficient used to compute running averages of square gradient.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
@@ -27,6 +28,9 @@ class AdaFactor(Optimizer, BaseOptimizer):
         is being used.
     :param eps1: float. term added to the denominator to improve numerical stability.
    :param eps2: float. term added to the denominator to improve numerical stability.
+    :param momentum_dtype: torch.dtype. dtype of the first momentum. the ViT paper observed that storing momentum
+        in half-precision (bfloat16) does not affect training dynamics or the final outcome, while reducing the
+        optimizer state overhead from 2-fold to 1.5-fold.
     """
 
     def __init__(
@@ -45,6 +49,7 @@ def __init__(
         warmup_init: bool = False,
         eps1: float = 1e-30,
         eps2: float = 1e-3,
+        momentum_dtype: torch.dtype = torch.bfloat16,
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
@@ -56,6 +61,7 @@ def __init__(
         self.clip_threshold = clip_threshold
         self.eps1 = eps1
         self.eps2 = eps2
+        self.momentum_dtype = momentum_dtype
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -87,7 +93,8 @@ def reset(self):
                 grad_shape: Tuple[int, ...] = grad.shape
                 factored: bool = self.get_options(grad_shape)
 
-                state['exp_avg'] = torch.zeros_like(p)
+                if group['betas'][0] is not None:
+                    state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)
 
                 if factored:
                     state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -149,7 +156,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             else:
                 group['step'] = 1
 
-            beta1, _ = group['betas']
+            beta1, beta2 = group['betas']
 
             beta2_t: float = 1.0 - math.pow(group['step'], self.decay_rate)
 
@@ -167,7 +174,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 factored: bool = self.get_options(grad_shape)
 
                 if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
+                    if beta1 is not None:
+                        state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)
 
                     if factored:
                         state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -205,6 +213,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     exp_avg_sq = state['exp_avg_sq']
                     exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
+                    exp_avg_sq.clamp_(max=beta2)
+
                     torch.rsqrt(exp_avg_sq, out=update)
 
                 if group['ams_bound']:
@@ -216,8 +226,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0)).mul_(lr)
 
-                exp_avg = state['exp_avg']
-                exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
+                if beta1 is not None:
+                    exp_avg = state['exp_avg']
+                    exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
+
+                    update = exp_avg
 
                 self.apply_weight_decay(
                     p=p,
@@ -228,6 +241,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     fixed_decay=group['fixed_decay'],
                 )
 
-                p.add_(-exp_avg)
+                p.add_(-update)
 
         return loss
diff --git a/tests/constants.py b/tests/constants.py
index 34f852146..eee2ecb1c 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -348,8 +348,9 @@
     (DAdaptLion, {'lr': 3e0, 'weight_decay': 1e-3}, 20),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
-    (AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
-    (AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 125),
+    (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
+    (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120),
+    (AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50),
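For readers of the patch, here is a condensed, non-factored sketch of the changed `step()` path. It is not the library's code: the factored row/column statistics, `ams_bound`, weight decay, and relative-step learning-rate scaling are omitted, and the helper name `adafactor_update_sketch` plus its default hyper-parameters are made up for illustration.

```python
import math
from typing import Optional

import torch


def adafactor_update_sketch(
    p: torch.Tensor,
    grad: torch.Tensor,
    state: dict,
    step: int,
    beta1: Optional[float],
    beta2: float = 0.999,
    lr: float = 1e-3,
    decay_rate: float = -0.8,
    clip_threshold: float = 1.0,
    eps1: float = 1e-30,
    momentum_dtype: torch.dtype = torch.bfloat16,
) -> None:
    # Simplified view of the tweaked update; call under torch.no_grad()
    # if `p` is a leaf tensor that requires grad.
    beta2_t = 1.0 - math.pow(step, decay_rate)

    # second-moment estimate (non-factored branch), now clipped at beta2
    update = grad.pow(2).add_(eps1)
    exp_avg_sq = state.setdefault('exp_avg_sq', torch.zeros_like(grad))
    exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
    exp_avg_sq.clamp_(max=beta2)

    torch.rsqrt(exp_avg_sq, out=update)
    update.mul_(grad)

    # RMS clipping and learning-rate scaling
    rms = update.norm(2) / math.sqrt(update.numel())
    update.div_((rms / clip_threshold).clamp(min=1.0)).mul_(lr)

    if beta1 is not None:
        # the first momentum exists only when beta1 is given; stored in `momentum_dtype`
        exp_avg = state.setdefault('exp_avg', torch.zeros_like(p, dtype=momentum_dtype))
        exp_avg.mul_(beta1).add_(update.to(exp_avg.dtype), alpha=1.0 - beta1)
        update = exp_avg

    # when beta1 is None, the clipped update is applied directly
    p.add_(-update)
```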
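And a minimal usage sketch for the new constructor options, mirroring the test constants added above. The top-level `AdaFactor` import path is assumed from the package layout, and the model and hyper-parameters are placeholders.

```python
import torch
from pytorch_optimizer import AdaFactor

model = torch.nn.Linear(10, 2)

# betas=(None, 0.999) skips the first momentum entirely (no `exp_avg` state is kept);
# momentum_dtype only matters when beta1 is given, e.g. betas=(0.9, 0.999).
optimizer = AdaFactor(
    model.parameters(),
    lr=1e-3,
    betas=(None, 0.999),
    weight_decay=1e-3,
    momentum_dtype=torch.bfloat16,
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```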