diff --git a/pytorch_optimizer/optimizer/lookahead.py b/pytorch_optimizer/optimizer/lookahead.py index 507971f2c..6e149f0f4 100644 --- a/pytorch_optimizer/optimizer/lookahead.py +++ b/pytorch_optimizer/optimizer/lookahead.py @@ -36,7 +36,6 @@ def __init__( self.pullback_momentum = pullback_momentum self.optimizer = optimizer - self.param_groups = self.optimizer.param_groups self.state: STATE = defaultdict(dict) @@ -58,6 +57,10 @@ def __init__( **optimizer.defaults, } + @property + def param_groups(self): + return self.optimizer.param_groups + def __getstate__(self): return { 'state': self.state,