Update docs (custom policy, type hints) (#167)
* Change import

* Update custom policy doc

* Re-enable sphinx_autodoc_typehints

* Update docker image

* Attempt to fix read the doc build error

* Add sphinx_autodoc_typehints to read the doc env

* Fix pip version

* Add full custom policy example

* Fix
araffin authored Sep 29, 2020
1 parent 8b16324 commit 2c924f5
Showing 17 changed files with 206 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .gitlab-ci.yml
@@ -1,4 +1,4 @@
image: stablebaselines/stable-baselines3-cpu:0.9.0a1
image: stablebaselines/stable-baselines3-cpu:0.9.0a2

type-check:
script:
5 changes: 3 additions & 2 deletions docs/conda_env.yml
@@ -4,13 +4,14 @@ channels:
- defaults
dependencies:
- cpuonly=1.0=0
- pip=20.0
- pip=20.2
- python=3.6
- pytorch=1.5.0=py3.6_cpu_0
- pip:
- gym==0.17.2
- gym>=0.17.2
- cloudpickle
- opencv-python-headless
- pandas
- numpy
- matplotlib
- sphinx_autodoc_typehints
4 changes: 2 additions & 2 deletions docs/conf.py
@@ -72,7 +72,7 @@ def __getattr__(cls, name):
# ones.
extensions = [
"sphinx.ext.autodoc",
# 'sphinx_autodoc_typehints',
"sphinx_autodoc_typehints",
"sphinx.ext.autosummary",
"sphinx.ext.mathjax",
"sphinx.ext.ifconfig",
@@ -128,7 +128,7 @@ def __getattr__(cls, name):


def setup(app):
app.add_stylesheet("css/baselines_theme.css")
app.add_css_file("css/baselines_theme.css")


# Theme options are theme-specific and customize the look and feel of a theme
167 changes: 164 additions & 3 deletions docs/guide/custom_policy.rst
@@ -1,11 +1,15 @@
.. _custom_policy:

Custom Policy Network
---------------------
=====================

Stable Baselines3 provides policy networks for images (CnnPolicies)
and other types of input features (MlpPolicies).
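
For orientation, here is a minimal sketch of selecting one of these built-in policies; the algorithm and environment below are arbitrary choices for illustration, not part of this commit:

.. code-block:: python

  from stable_baselines3 import PPO

  # "MlpPolicy" works on flat vector observations; "CnnPolicy" expects image observations
  model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
  model.learn(total_timesteps=1000)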


Custom Policy Architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^

One way of customising the policy network architecture is to pass arguments when creating the model,
using the ``policy_kwargs`` parameter:

@@ -41,6 +45,68 @@ You can also easily define a custom architecture for the policy (or value) network
``policy_kwargs`` is particularly useful when doing hyperparameter search.
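
To make the idea concrete, here is a small sketch of passing ``policy_kwargs``; the layer sizes and activation function are arbitrary values for illustration, not prescribed by this commit:

.. code-block:: python

  import torch as th

  from stable_baselines3 import PPO

  # Two hidden layers of 64 units for both the policy and the value network (arbitrary sizes)
  policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=[dict(pi=[64, 64], vf=[64, 64])])
  model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
  model.learn(total_timesteps=1000)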


Custom Feature Extractor
^^^^^^^^^^^^^^^^^^^^^^^^

If you want to have a custom feature extractor (e.g. a custom CNN when using images), you can define a class
that derives from ``BaseFeaturesExtractor`` and then pass it to the model when training.

.. code-block:: python

  import gym
  import torch as th
  import torch.nn as nn

  from stable_baselines3 import PPO
  from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


  class CustomCNN(BaseFeaturesExtractor):
      """
      :param observation_space: (gym.Space)
      :param features_dim: (int) Number of features extracted.
          This corresponds to the number of units for the last layer.
      """

      def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
          super(CustomCNN, self).__init__(observation_space, features_dim)
          # We assume CxHxW images (channels first)
          # Re-ordering will be done by preprocessing or wrapper
          n_input_channels = observation_space.shape[0]
          self.cnn = nn.Sequential(
              nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
              nn.ReLU(),
              nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
              nn.ReLU(),
              nn.Flatten(),
          )

          # Compute shape by doing one forward pass
          with th.no_grad():
              n_flatten = self.cnn(
                  th.as_tensor(observation_space.sample()[None]).float()
              ).shape[1]

          self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

      def forward(self, observations: th.Tensor) -> th.Tensor:
          return self.linear(self.cnn(observations))


  policy_kwargs = dict(
      features_extractor_class=CustomCNN,
      features_extractor_kwargs=dict(features_dim=128),
  )
  model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
  model.learn(1000)

On-Policy Algorithms
^^^^^^^^^^^^^^^^^^^^

Shared Networks
---------------

The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows you to specify the number and size of the hidden layers and how many
of them are shared between the policy network and the value network. It is assumed to be a list with the following
@@ -99,7 +165,102 @@ Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
(diagram: observation → shared layer of 128 units → policy branch of 16 units producing the action, and value branch of 256 units producing the value)
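
As a sketch of how that shared-then-diverging configuration might be passed in practice (the algorithm and environment here are arbitrary choices for illustration):

.. code-block:: python

  from stable_baselines3 import A2C

  # One shared layer of 128 units, then a 16-unit policy head and a 256-unit value head
  policy_kwargs = dict(net_arch=[128, dict(pi=[16], vf=[256])])
  model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
  model.learn(total_timesteps=1000)
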
Advanced Example
~~~~~~~~~~~~~~~~

If your task requires even more granular control over the policy/value architecture, you can redefine the policy directly:

If your task requires even more granular control over the policy architecture, you can redefine the policy directly.

**TODO**
.. code-block:: python

  from typing import Callable, Dict, List, Optional, Tuple, Type, Union

  import gym
  import torch as th
  from torch import nn

  from stable_baselines3 import PPO
  from stable_baselines3.common.policies import ActorCriticPolicy


  class CustomNetwork(nn.Module):
      """
      Custom network for policy and value function.
      It receives as input the features extracted by the feature extractor.

      :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
      :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
      :param last_layer_dim_vf: (int) number of units for the last layer of the value network
      """

      def __init__(
          self,
          feature_dim: int,
          last_layer_dim_pi: int = 64,
          last_layer_dim_vf: int = 64,
      ):
          super(CustomNetwork, self).__init__()

          # IMPORTANT:
          # Save output dimensions, used to create the distributions
          self.latent_dim_pi = last_layer_dim_pi
          self.latent_dim_vf = last_layer_dim_vf

          # Policy network
          self.policy_net = nn.Sequential(
              nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
          )
          # Value network
          self.value_net = nn.Sequential(
              nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
          )

      def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
          """
          :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
              If all layers are shared, then ``latent_policy == latent_value``
          """
          return self.policy_net(features), self.value_net(features)


  class CustomActorCriticPolicy(ActorCriticPolicy):
      def __init__(
          self,
          observation_space: gym.spaces.Space,
          action_space: gym.spaces.Space,
          lr_schedule: Callable[[float], float],
          net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
          activation_fn: Type[nn.Module] = nn.Tanh,
          *args,
          **kwargs,
      ):
          super(CustomActorCriticPolicy, self).__init__(
              observation_space,
              action_space,
              lr_schedule,
              net_arch,
              activation_fn,
              # Pass remaining arguments to base class
              *args,
              **kwargs,
          )
          # Disable orthogonal initialization
          self.ortho_init = False

      def _build_mlp_extractor(self) -> None:
          self.mlp_extractor = CustomNetwork(self.features_dim)


  model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
  model.learn(5000)

.. TODO (see https://github.com/DLR-RM/stable-baselines3/issues/113)
.. Off-Policy Algorithms
.. ^^^^^^^^^^^^^^^^^^^^^
..
.. If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
.. you can easily redefine the actor class for instance.
9 changes: 9 additions & 0 deletions docs/misc/changelog.rst
@@ -12,6 +12,13 @@ Breaking Changes:
- Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and
``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params``
- Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity
- ``make_atari_env``, ``make_vec_env`` and ``set_random_seed`` must now be imported from their submodules (and not directly from ``stable_baselines3.common``):

  .. code-block:: python

    from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
    from stable_baselines3.common.utils import set_random_seed

New Features:
^^^^^^^^^^^^^
@@ -47,6 +54,8 @@ Others:
Documentation:
^^^^^^^^^^^^^^
- Added ``StopTrainingOnMaxEpisodes`` details and example (@xicocaio)
- Updated custom policy section (added custom feature extractor example)
- Re-enabled ``sphinx_autodoc_typehints``



2 changes: 1 addition & 1 deletion docs/modules/td3.rst
@@ -64,7 +64,7 @@ Example
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make('Pendulum-v0')
# The noise objects for TD3
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
2 changes: 1 addition & 1 deletion setup.py
@@ -103,7 +103,7 @@
# For spelling
"sphinxcontrib.spelling",
# Type hints support
# 'sphinx-autodoc-typehints'
"sphinx-autodoc-typehints",
],
"extra": [
# For render
2 changes: 0 additions & 2 deletions stable_baselines3/common/__init__.py
@@ -1,2 +0,0 @@
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed
12 changes: 4 additions & 8 deletions stable_baselines3/common/callbacks.py
@@ -1,19 +1,15 @@
import os
import typing
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import gym
import numpy as np

from stable_baselines3.common import logger
from stable_baselines3.common import base_class, logger # pytype: disable=pyi-error
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization

if typing.TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm # pytype: disable=pyi-error


class BaseCallback(ABC):
"""
@@ -25,7 +21,7 @@ class BaseCallback(ABC):
def __init__(self, verbose: int = 0):
super(BaseCallback, self).__init__()
# The RL model
self.model = None # type: Optional[BaseAlgorithm]
self.model = None # type: Optional[base_class.BaseAlgorithm]
# An alias for self.model.get_env(), the environment used for training
self.training_env = None # type: Union[gym.Env, VecEnv, None]
# Number of time the callback was called
@@ -41,7 +37,7 @@ def __init__(self, verbose: int = 0):
self.parent = None # type: Optional[BaseCallback]

# Type hint as string to avoid circular import
def init_callback(self, model: "BaseAlgorithm") -> None:
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
"""
Initialize the callback by saving references to the
RL model and the training environment for convenience.
@@ -137,7 +133,7 @@ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
if callback is not None:
self.callback.parent = self

def init_callback(self, model: "BaseAlgorithm") -> None:
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
super(EventCallback, self).init_callback(model)
if self.callback is not None:
self.callback.init_callback(self.model)
2 changes: 1 addition & 1 deletion stable_baselines3/common/distributions.py
@@ -6,7 +6,7 @@
import gym
import torch as th
from gym import spaces
from torch import nn as nn
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal

from stable_baselines3.common.preprocessing import get_action_dim
7 changes: 2 additions & 5 deletions stable_baselines3/common/evaluation.py
@@ -1,17 +1,14 @@
import typing
from typing import Callable, List, Optional, Tuple, Union

import gym
import numpy as np

from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import VecEnv

if typing.TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm


def evaluate_policy(
model: "BaseAlgorithm",
model: "base_class.BaseAlgorithm",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
17 changes: 12 additions & 5 deletions stable_baselines3/common/policies.py
@@ -8,7 +8,7 @@
import gym
import numpy as np
import torch as th
from torch import nn as nn
from torch import nn

from stable_baselines3.common.distributions import (
BernoulliDistribution,
@@ -439,17 +439,24 @@ def reset_noise(self, n_envs: int = 1) -> None:
assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)

def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)

def _build(self, lr_schedule: Callable[[float], float]) -> None:
"""
Create the networks and the optimizer.
:param lr_schedule: (Callable) Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)
self._build_mlp_extractor()

latent_dim_pi = self.mlp_extractor.latent_dim_pi

2 changes: 1 addition & 1 deletion stable_baselines3/common/torch_layers.py
@@ -3,7 +3,7 @@

import gym
import torch as th
from torch import nn as nn
from torch import nn

from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3.common.utils import get_device