diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 16df2ab1a..a8f2619b0 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -34,6 +34,8 @@ Bug Fixes:
 - Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
 - Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
 - Fixed DQN target network sharing feature extractor with the main network.
+- Fixed storing correct ``dones`` in on-policy algorithm rollout collection. (@andyshih12)
+- Fixed number of filters in final convolutional layer in NatureCNN to match original implementation.
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -49,6 +51,7 @@ Others:
 - Ignored errors from newer pytype version
 - Added a check when using ``gSDE``
 - Removed codacy dependency from Dockerfile
+- Added ``common.sb2_compat.RMSpropTFLike`` optimizer, which corresponds more closely to the TensorFlow implementation of RMSprop.
 
 Documentation:
 ^^^^^^^^^^^^^^
diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst
index 141fa0bb1..460d1a6e3 100644
--- a/docs/modules/a2c.rst
+++ b/docs/modules/a2c.rst
@@ -10,6 +10,14 @@ A synchronous, deterministic variant of `Asynchronous Advantage Actor Critic (A3
 It uses multiple workers to avoid the use of a replay buffer.
 
+.. warning::
+
+  If you find training unstable or want to match the performance of stable-baselines A2C, consider using
+  the ``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
+  You can change the optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``.
+  Read more `here `_.
+
+
 Notes
 -----
diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py
index cc1d0782e..dc32476bf 100644
--- a/stable_baselines3/a2c/a2c.py
+++ b/stable_baselines3/a2c/a2c.py
@@ -116,6 +116,7 @@ def train(self) -> None:
         # Update optimizer learning rate
         self._update_learning_rate(self.policy.optimizer)
 
+        # This will only loop once (get all data in one go)
         for rollout_data in self.rollout_buffer.get(batch_size=None):
 
             actions = rollout_data.actions
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 1ba7cc0ce..a3de8cef8 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -123,6 +123,7 @@ def __init__(
         self.tensorboard_log = tensorboard_log
         self.lr_schedule = None  # type: Optional[Callable]
         self._last_obs = None  # type: Optional[np.ndarray]
+        self._last_dones = None  # type: Optional[np.ndarray]
         # When using VecNormalize:
         self._last_original_obs = None  # type: Optional[np.ndarray]
         self._episode_num = 0
@@ -474,6 +475,7 @@ def _setup_learn(
         # Avoid resetting the environment when calling ``.learn()`` consecutive times
         if reset_num_timesteps or self._last_obs is None:
             self._last_obs = self.env.reset()
+            self._last_dones = np.zeros((self.env.num_envs,), dtype=np.bool)
             # Retrieve unnormalized observation for saving into the buffer
             if self._vec_normalize_env is not None:
                 self._last_original_obs = self._vec_normalize_env.get_original_obs()
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index f84d18f34..8c9cb8b7e 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -173,8 +173,9 @@ def collect_rollouts(
             if isinstance(self.action_space, gym.spaces.Discrete):
                 # Reshape in case of discrete action
                 actions = actions.reshape(-1, 1)
-            rollout_buffer.add(self._last_obs, actions, rewards, dones, values, log_probs)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs)
             self._last_obs = new_obs
+            self._last_dones = dones
 
         rollout_buffer.compute_returns_and_advantage(values, dones=dones)
 
diff --git a/stable_baselines3/common/sb2_compat/__init__.py b/stable_baselines3/common/sb2_compat/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
new file mode 100644
index 000000000..46ef4b06f
--- /dev/null
+++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
@@ -0,0 +1,126 @@
+import torch
+from torch.optim import Optimizer
+
+
+class RMSpropTFLike(Optimizer):
+    r"""Implements the RMSprop algorithm with a closer match to the TensorFlow version.
+
+    Intended for reproducibility with the original stable-baselines. Use this
+    version with e.g. A2C for more stable learning than with the PyTorch
+    RMSprop. Based on the PyTorch v1.5.0 implementation of RMSprop.
+
+    See a more thorough conversion in the pytorch-image-models repository:
+    https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py
+
+    Changes to the original RMSprop:
+        - Move epsilon inside the square root
+        - Initialize the squared gradient to ones rather than zeros
+
+    Proposed by G. Hinton in his
+    `course `_.
+
+    The centered version first appears in `Generating Sequences
+    With Recurrent Neural Networks `_.
+
+    The implementation here adds epsilon to the moving average of the squared
+    gradient before taking the square root (the stock PyTorch RMSprop
+    interchanges these two operations). The effective learning rate is thus
+    :math:`\alpha/\sqrt{v + \epsilon}`, where :math:`\alpha` is the scheduled
+    learning rate and :math:`v` is the weighted moving average of the squared
+    gradient.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-2)
+        momentum (float, optional): momentum factor (default: 0)
+        alpha (float, optional): smoothing constant (default: 0.99)
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        centered (bool, optional): if ``True``, compute the centered RMSprop,
+            where the gradient is normalized by an estimate of its variance
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+
+    """
+
+    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= momentum:
+            raise ValueError("Invalid momentum value: {}".format(momentum))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= alpha:
+            raise ValueError("Invalid alpha value: {}".format(alpha))
+
+        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
+        super(RMSpropTFLike, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(RMSpropTFLike, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("momentum", 0)
+            group.setdefault("centered", False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("RMSpropTF does not support sparse gradients") + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # PyTorch initialized to zeros here + state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format) + if group["momentum"] > 0: + state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) + if group["centered"]: + state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + square_avg = state["square_avg"] + alpha = group["alpha"] + + state["step"] += 1 + + if group["weight_decay"] != 0: + grad = grad.add(p, alpha=group["weight_decay"]) + + square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) + + if group["centered"]: + grad_avg = state["grad_avg"] + grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) + # PyTorch added epsilon after square root + # avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps']) + avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_() + else: + # PyTorch added epsilon after square root + # avg = square_avg.sqrt().add_(group['eps']) + avg = square_avg.add(group["eps"]).sqrt_() + + if group["momentum"] > 0: + buf = state["momentum_buffer"] + buf.mul_(group["momentum"]).addcdiv_(grad, avg) + p.add_(buf, alpha=-group["lr"]) + else: + p.addcdiv_(grad, avg, value=-group["lr"]) + + return loss diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 9c74017cb..9429a86eb 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -74,7 +74,7 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): nn.ReLU(), nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.ReLU(), - nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), nn.ReLU(), nn.Flatten(), ) diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index cd379126c..c1e08dfac 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -2,6 +2,7 @@ import torch as th from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike @pytest.mark.parametrize( @@ -32,3 +33,8 @@ def test_custom_offpolicy(model_class, net_arch): def test_custom_optimizer(model_class, optimizer_kwargs): policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32]) _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000) + + +def test_tf_like_rmsprop_optimizer(): + policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32]) + _ = A2C("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)