diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e04e59ee4..6fa1a2c17 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -50,9 +50,9 @@ def my_function(arg1: type1, arg2: type2) -> returntype: """ Short description of the function. - :param arg1: (type1) describe what is arg1 - :param arg2: (type2) describe what is arg2 - :return: (returntype) describe what is returned + :param arg1: describe what is arg1 + :param arg2: describe what is arg2 + :return: describe what is returned """ ... return my_variable diff --git a/docs/_static/css/baselines_theme.css b/docs/_static/css/baselines_theme.css index 89455aa88..450864efe 100644 --- a/docs/_static/css/baselines_theme.css +++ b/docs/_static/css/baselines_theme.css @@ -50,3 +50,12 @@ a.icon.icon-home { .codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] { background: #f8f8f8;; } + +/* Change style of types in the docstrings .rst-content .field-list */ +.field-list .xref.py.docutils, .field-list code.docutils, .field-list .docutils.literal.notranslate +{ + border: None; + padding-left: 0; + padding-right: 0; + color: #404040; +} diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 11122556d..d74b40235 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -56,6 +56,7 @@ Documentation: - Added ``StopTrainingOnMaxEpisodes`` details and example (@xicocaio) - Updated custom policy section (added custom feature extractor example) - Re-enable ``sphinx_autodoc_typehints`` +- Updated doc style for type hints and remove duplicated type hints diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index dc32476bf..ee25ee2f6 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -21,34 +21,34 @@ class A2C(OnPolicyAlgorithm): Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752 - :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) The learning rate, it can be a function - :param n_steps: (int) The number of steps to run for each environment per update + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + :param n_steps: The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) - :param gamma: (float) Discount factor - :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. - :param ent_coef: (float) Entropy coefficient for the loss calculation - :param vf_coef: (float) Value function coefficient for the loss calculation - :param max_grad_norm: (float) The maximum value for the gradient clipping - :param rms_prop_eps: (float) RMSProp epsilon. It stabilizes square root computation in denominator + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in denominator of RMSProp update - :param use_rms_prop: (bool) Whether to use RMSprop (default) or Adam as optimizer - :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) + :param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) - :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param normalize_advantage: (bool) Whether to normalize or not the advantage - :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param normalize_advantage: Whether to normalize or not the advantage + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index fdae2d327..977f6d32c 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -18,8 +18,8 @@ def __init__(self, env: gym.Env, noop_max: int = 30): Sample initial states by taking random number of no-ops on reset. No-op is assumed to be action 0. - :param env: (gym.Env) the environment to wrap - :param noop_max: (int) the maximum value of no-ops to run + :param env: the environment to wrap + :param noop_max: the maximum value of no-ops to run """ gym.Wrapper.__init__(self, env) self.noop_max = noop_max @@ -47,7 +47,7 @@ def __init__(self, env: gym.Env): """ Take action on reset for environments that are fixed until firing. - :param env: (gym.Env) the environment to wrap + :param env: the environment to wrap """ gym.Wrapper.__init__(self, env) assert env.unwrapped.get_action_meanings()[1] == "FIRE" @@ -70,7 +70,7 @@ def __init__(self, env: gym.Env): Make end-of-life == end-of-episode, but only reset on true game over. Done by DeepMind for the DQN and co. since it helps value estimation. - :param env: (gym.Env) the environment to wrap + :param env: the environment to wrap """ gym.Wrapper.__init__(self, env) self.lives = 0 @@ -97,7 +97,7 @@ def reset(self, **kwargs) -> np.ndarray: and the learner need not know about any of this behind-the-scenes. :param kwargs: Extra keywords passed to env.reset() call - :return: (np.ndarray) the first observation of the environment + :return: the first observation of the environment """ if self.was_real_done: obs = self.env.reset(**kwargs) @@ -113,8 +113,8 @@ def __init__(self, env: gym.Env, skip: int = 4): """ Return only every ``skip``-th frame (frameskipping) - :param env: (gym.Env) the environment - :param skip: (int) number of ``skip``-th frame + :param env: the environment + :param skip: number of ``skip``-th frame """ gym.Wrapper.__init__(self, env) # most recent raw observations (for max pooling across time steps) @@ -126,8 +126,8 @@ def step(self, action: int) -> GymStepReturn: Step the environment with the given action Repeat action, sum reward, and max over last observations. - :param action: ([int] or [float]) the action - :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information + :param action: the action + :return: observation, reward, done, information """ total_reward = 0.0 done = None @@ -155,7 +155,7 @@ def __init__(self, env: gym.Env): """ Clips the reward to {+1, 0, -1} by its sign. - :param env: (gym.Env) the environment + :param env: the environment """ gym.RewardWrapper.__init__(self, env) @@ -163,8 +163,8 @@ def reward(self, reward: float) -> float: """ Bin reward to {+1, 0, -1} by its sign. - :param reward: (float) - :return: (float) + :param reward: + :return: """ return np.sign(reward) @@ -175,9 +175,9 @@ def __init__(self, env: gym.Env, width: int = 84, height: int = 84): Convert to grayscale and warp frames to 84x84 (default) as done in the Nature paper and later work. - :param env: (gym.Env) the environment - :param width: (int) - :param height: (int) + :param env: the environment + :param width: + :param height: """ gym.ObservationWrapper.__init__(self, env) self.width = width @@ -190,8 +190,8 @@ def observation(self, frame: np.ndarray) -> np.ndarray: """ returns the current observation from a frame - :param frame: (np.ndarray) environment frame - :return: (np.ndarray) the observation + :param frame: environment frame + :return: the observation """ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) @@ -212,13 +212,13 @@ class AtariWrapper(gym.Wrapper): * Grayscale observation * Clip reward to {-1, 0, 1} - :param env: (gym.Env) gym environment - :param noop_max: (int): max number of no-ops - :param frame_skip: (int): the frequency at which the agent experiences the game. - :param screen_size: (int): resize Atari frame - :param terminal_on_life_loss: (bool): if True, then step() returns done=True whenever a + :param env: gym environment + :param noop_max:: max number of no-ops + :param frame_skip:: the frequency at which the agent experiences the game. + :param screen_size:: resize Atari frame + :param terminal_on_life_loss:: if True, then step() returns done=True whenever a life is lost. - :param clip_reward: (bool) If True (default), the reward is clip to {-1, 0, 1} depending on its sign. + :param clip_reward: If True (default), the reward is clip to {-1, 0, 1} depending on its sign. """ def __init__( diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 2015f0593..d30830bb3 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -32,9 +32,9 @@ def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]: """If env is a string, make the environment; otherwise, return env. - :param env: (Union[GymEnv, str, None]) The environment to learn from. - :param monitor_wrapper: (bool) Whether to wrap env in a Monitor when creating env. - :param verbose: (int) logging verbosity + :param env: The environment to learn from. + :param monitor_wrapper: Whether to wrap env in a Monitor when creating env. + :param verbose: logging verbosity :return A Gym (vector) environment. """ if isinstance(env, str): @@ -51,28 +51,28 @@ class BaseAlgorithm(ABC): """ The base of RL algorithms - :param policy: (Type[BasePolicy]) Policy object - :param env: (Union[GymEnv, str, None]) The environment to learn from + :param policy: Policy object + :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) - :param policy_base: (Type[BasePolicy]) The base policy used by this method - :param learning_rate: (float or callable) learning rate for the optimizer, + :param policy_base: The base policy used by this method + :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) - :param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation - :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) - :param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug - :param device: (Union[th.device, str]) Device on which the code should run. + :param policy_kwargs: Additional arguments to be passed to the policy on creation + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param verbose: The verbosity level: 0 none, 1 training information, 2 debug + :param device: Device on which the code should run. By default, it will try to use a Cuda compatible device and fallback to cpu if it is not possible. - :param support_multi_env: (bool) Whether the algorithm supports training + :param support_multi_env: Whether the algorithm supports training with multiple environments (as in A2C) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param monitor_wrapper: (bool) When creating an environment, whether to wrap it + :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. - :param seed: (Optional[int]) Seed for the pseudo random generators - :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) + :param seed: Seed for the pseudo random generators + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) - :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) """ @@ -181,8 +181,8 @@ def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]: """ Return the environment that will be used for evaluation. - :param eval_env: (Optional[GymEnv])) - :return: (Optional[GymEnv]) + :param eval_env:) + :return: """ if eval_env is None: eval_env = self.eval_env @@ -210,7 +210,7 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.o Update the optimizers learning rate using the current learning rate schedule and the current progress remaining (from 1 to 0). - :param optimizers: (Union[List[th.optim.Optimizer], th.optim.Optimizer]) + :param optimizers: An optimizer or a list of optimizers. """ # Log the current learning rate @@ -228,7 +228,7 @@ def _excluded_save_params(self) -> List[str]: as they take up a lot of space. PyTorch variables should be excluded with this so they can be stored with ``th.save``. - :return: (List[str]) List of parameters that should be excluded from being saved with pickle. + :return: List of parameters that should be excluded from being saved with pickle. """ return [ "policy", @@ -250,7 +250,7 @@ def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: "policy.optimizer" would point to ``optimizer`` object of ``self.policy`` if this object. - :return: (Tuple[List[str], List[str]]) + :return: List of Torch variables whose state dicts to save (e.g. th.nn.Modules), and list of other Torch variables to store with ``th.save``. """ @@ -267,12 +267,12 @@ def _init_callback( log_path: Optional[str] = None, ) -> BaseCallback: """ - :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm. - :param eval_freq: (Optional[int]) How many steps between evaluations; if None, do not evaluate. - :param n_eval_episodes: (int) How many episodes to play per evaluation - :param n_eval_episodes: (int) Number of episodes to rollout during evaluation. - :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved - :return: (BaseCallback) A hybrid callback calling `callback` and performing evaluation. + :param callback: Callback(s) called at every step with state of the algorithm. + :param eval_freq: How many steps between evaluations; if None, do not evaluate. + :param n_eval_episodes: How many episodes to play per evaluation + :param n_eval_episodes: Number of episodes to rollout during evaluation. + :param log_path: Path to a folder where the evaluations will be saved + :return: A hybrid callback calling `callback` and performing evaluation. """ # Convert a list of callbacks into a callback if isinstance(callback, list): @@ -310,15 +310,15 @@ def _setup_learn( """ Initialize different variables needed for training. - :param total_timesteps: (int) The total number of samples (env steps) to train on - :param eval_env: (Optional[VecEnv]) Environment to use for evaluation. - :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm. - :param eval_freq: (int) How many steps between evaluations - :param n_eval_episodes: (int) How many episodes to play per evaluation - :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved - :param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute - :param tb_log_name: (str) the name of the run for tensorboard log - :return: (Tuple[int, BaseCallback]) + :param total_timesteps: The total number of samples (env steps) to train on + :param eval_env: Environment to use for evaluation. + :param callback: Callback(s) called at every step with state of the algorithm. + :param eval_freq: How many steps between evaluations + :param n_eval_episodes: How many episodes to play per evaluation + :param log_path: Path to a folder where the evaluations will be saved + :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute + :param tb_log_name: the name of the run for tensorboard log + :return: """ self.start_time = time.time() if self.ep_info_buffer is None or reset_num_timesteps: @@ -363,7 +363,7 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.nd Retrieve reward and episode length and update the buffer if using Monitor wrapper. - :param infos: ([dict]) + :param infos: """ if dones is None: dones = np.array([False] * len(infos)) @@ -379,7 +379,7 @@ def get_env(self) -> Optional[VecEnv]: """ Returns the current environment (can be None if not defined). - :return: (Optional[VecEnv]) The current environment + :return: The current environment """ return self.env @@ -387,7 +387,8 @@ def get_vec_normalize_env(self) -> Optional[VecNormalize]: """ Return the ``VecNormalize`` wrapper of the training env if it exists. - :return: Optional[VecNormalize] The ``VecNormalize`` env. + + :return: The ``VecNormalize`` env. """ return self._vec_normalize_env @@ -425,16 +426,16 @@ def learn( """ Return a trained model. - :param total_timesteps: (int) The total number of samples (env steps) to train on - :param callback: (MaybeCallback) callback(s) called at every step with state of the algorithm. - :param log_interval: (int) The number of timesteps before logging. - :param tb_log_name: (str) the name of the run for TensorBoard logging - :param eval_env: (gym.Env) Environment that will be used to evaluate the agent - :param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little) - :param n_eval_episodes: (int) Number of episode to evaluate the agent - :param eval_log_path: (Optional[str]) Path to a folder where the evaluations will be saved - :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging) - :return: (BaseAlgorithm) the trained model + :param total_timesteps: The total number of samples (env steps) to train on + :param callback: callback(s) called at every step with state of the algorithm. + :param log_interval: The number of timesteps before logging. + :param tb_log_name: the name of the run for TensorBoard logging + :param eval_env: Environment that will be used to evaluate the agent + :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little) + :param n_eval_episodes: Number of episode to evaluate the agent + :param eval_log_path: Path to a folder where the evaluations will be saved + :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) + :return: the trained model """ def predict( @@ -447,11 +448,11 @@ def predict( """ Get the model's action(s) from an observation - :param observation: (np.ndarray) the input observation - :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies) - :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies) - :param deterministic: (bool) Whether or not to return deterministic actions. - :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next state (used in recurrent policies) """ return self.policy.predict(observation, state, mask, deterministic) @@ -461,7 +462,7 @@ def set_random_seed(self, seed: Optional[int] = None) -> None: Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space) - :param seed: (int) + :param seed: """ if seed is None: return @@ -488,7 +489,7 @@ def set_parameters( :param exact_match: If True, the given parameters should include parameters for each module and each of their parameters, otherwise raises an Exception. If set to False, this can be used to update only specific parameters. - :param device: (Union[th.device, str]) Device on which the code should run. + :param device: Device on which the code should run. """ params = None if isinstance(load_path_or_dict, dict): @@ -551,11 +552,11 @@ def load( """ Load the model from a zip-file - :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file (or a file-like) where to + :param path: path to the file (or a file-like) where to load the agent from :param env: the new environment to run the loaded model on (can be None if you only need prediction from a trained model) has priority over any saved environment - :param device: (Union[th.device, str]) Device on which the code should run. + :param device: Device on which the code should run. :param kwargs: extra arguments to change the model when loading """ data, params, pytorch_variables = load_from_zip_file(path, device=device) @@ -614,7 +615,7 @@ def get_parameters(self): Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions). - :return: (Dict[str, Dict]) Mapping of from names of the objects to PyTorch state-dicts. + :return: Mapping of from names of the objects to PyTorch state-dicts. """ state_dicts_names, _ = self._get_torch_save_params() params = {} @@ -633,7 +634,7 @@ def save( """ Save all the attributes of the object and the model parameters in a zip-file. - :param (Union[str, pathlib.Path, io.BufferedIOBase]): path to the file where the rl agent should be saved + :param path: path to the file where the rl agent should be saved :param exclude: name of parameters that should be excluded in addition to the default ones :param include: name of parameters that might be excluded but should be included anyway """ diff --git a/stable_baselines3/common/bit_flipping_env.py b/stable_baselines3/common/bit_flipping_env.py index b579fe157..96e3cfd4e 100644 --- a/stable_baselines3/common/bit_flipping_env.py +++ b/stable_baselines3/common/bit_flipping_env.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Optional, Union +from typing import Dict, Optional, Union import numpy as np from gym import GoalEnv, spaces @@ -14,11 +14,11 @@ class BitFlippingEnv(GoalEnv): In the continuous variant, if the ith action component has a value > 0, then the ith bit will be flipped. - :param n_bits: (int) Number of bits to flip - :param continuous: (bool) Whether to use the continuous actions version or not, + :param n_bits: Number of bits to flip + :param continuous: Whether to use the continuous actions version or not, by default, it uses the discrete one - :param max_steps: (Optional[int]) Max number of steps, by default, equal to n_bits - :param discrete_obs_space: (bool) Whether to use the discrete observation + :param max_steps: Max number of steps, by default, equal to n_bits + :param discrete_obs_space: Whether to use the discrete observation version or not, by default, it uses the MultiBinary one """ @@ -67,8 +67,8 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: """ Convert to discrete space if needed. - :param state: (np.ndarray) - :return: (np.ndarray or int) + :param state: + :return: """ if self.discrete_obs_space: # The internal state is the binary representation of the @@ -76,11 +76,11 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: return int(sum([state[i] * 2 ** i for i in range(len(state))])) return state - def _get_obs(self) -> OrderedDict: + def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: """ Helper to create the observation. - :return: (OrderedDict) + :return: """ return OrderedDict( [ @@ -90,7 +90,7 @@ def _get_obs(self) -> OrderedDict: ] ) - def reset(self) -> OrderedDict: + def reset(self) -> Dict[str, Union[int, np.ndarray]]: self.current_step = 0 self.state = self.obs_space.sample() return self._get_obs() diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 6c5895384..12f44811c 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -20,12 +20,12 @@ class BaseBuffer(object): """ Base class that represent a buffer (rollout or replay) - :param buffer_size: (int) Max number of element in the buffer - :param observation_space: (spaces.Space) Observation space - :param action_space: (spaces.Space) Action space - :param device: (Union[th.device, str]) PyTorch device + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device to which the values will be converted - :param n_envs: (int) Number of parallel environments + :param n_envs: Number of parallel environments """ def __init__( @@ -54,8 +54,8 @@ def swap_and_flatten(arr: np.ndarray) -> np.ndarray: to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) to [n_steps * n_envs, ...] (which maintain the order) - :param arr: (np.ndarray) - :return: (np.ndarray) + :param arr: + :return: """ shape = arr.shape if len(shape) < 3: @@ -64,7 +64,7 @@ def swap_and_flatten(arr: np.ndarray) -> np.ndarray: def size(self) -> int: """ - :return: (int) The current size of the buffer + :return: The current size of the buffer """ if self.full: return self.buffer_size @@ -93,10 +93,10 @@ def reset(self) -> None: def sample(self, batch_size: int, env: Optional[VecNormalize] = None): """ - :param batch_size: (int) Number of element to sample - :param env: (Optional[VecNormalize]) associated gym VecEnv + :param batch_size: Number of element to sample + :param env: associated gym VecEnv to normalize the observations/rewards when sampling - :return: (Union[RolloutBufferSamples, ReplayBufferSamples]) + :return: """ upper_bound = self.buffer_size if self.full else self.pos batch_inds = np.random.randint(0, upper_bound, size=batch_size) @@ -104,9 +104,9 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None): def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None): """ - :param batch_inds: (th.Tensor) - :param env: (Optional[VecNormalize]) - :return: (Union[RolloutBufferSamples, ReplayBufferSamples]) + :param batch_inds: + :param env: + :return: """ raise NotImplementedError() @@ -115,10 +115,10 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor: Convert a numpy array to a PyTorch tensor. Note: it copies the data by default - :param array: (np.ndarray) - :param copy: (bool) Whether to copy or not the data + :param array: + :param copy: Whether to copy or not the data (may be useful to avoid changing things be reference) - :return: (th.Tensor) + :return: """ if copy: return th.tensor(array).to(self.device) @@ -141,12 +141,12 @@ class ReplayBuffer(BaseBuffer): """ Replay buffer used in off-policy algorithms like SAC/TD3. - :param buffer_size: (int) Max number of element in the buffer - :param observation_space: (spaces.Space) Observation space - :param action_space: (spaces.Space) Action space - :param device: (th.device) - :param n_envs: (int) Number of parallel environments - :param optimize_memory_usage: (bool) Enable a memory efficient variant + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param n_envs: Number of parallel environments + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer which reduces by almost a factor two the memory used, at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -219,10 +219,10 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB as we should not sample the element with index `self.pos` See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 - :param batch_size: (int) Number of element to sample - :param env: (Optional[VecNormalize]) associated gym VecEnv + :param batch_size: Number of element to sample + :param env: associated gym VecEnv to normalize the observations/rewards when sampling - :return: (Union[RolloutBufferSamples, ReplayBufferSamples]) + :return: """ if not self.optimize_memory_usage: return super().sample(batch_size=batch_size, env=env) @@ -254,14 +254,14 @@ class RolloutBuffer(BaseBuffer): """ Rollout buffer used in on-policy algorithms like A2C/PPO. - :param buffer_size: (int) Max number of element in the buffer - :param observation_space: (spaces.Space) Observation space - :param action_space: (spaces.Space) Action space - :param device: (th.device) - :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. - :param gamma: (float) Discount factor - :param n_envs: (int) Number of parallel environments + :param gamma: Discount factor + :param n_envs: Number of parallel environments """ def __init__( @@ -306,8 +306,8 @@ def compute_returns_and_advantage(self, last_value: th.Tensor, dones: np.ndarray where R is the discounted reward with value bootstrap, set ``gae_lambda=1.0`` during initialization. - :param last_value: (th.Tensor) - :param dones: (np.ndarray) + :param last_value: + :param dones: """ # convert to numpy @@ -330,13 +330,13 @@ def add( self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray, value: th.Tensor, log_prob: th.Tensor ) -> None: """ - :param obs: (np.ndarray) Observation - :param action: (np.ndarray) Action - :param reward: (np.ndarray) - :param done: (np.ndarray) End of episode signal. - :param value: (th.Tensor) estimated value of the current state + :param obs: Observation + :param action: Action + :param reward: + :param done: End of episode signal. + :param value: estimated value of the current state following the current policy. - :param log_prob: (th.Tensor) log probability of the action + :param log_prob: log probability of the action following the current policy. """ if len(log_prob.shape) == 0: diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index d814fa6f8..05f4d14c9 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -15,7 +15,7 @@ class BaseCallback(ABC): """ Base class for callback. - :param verbose: (int) + :param verbose: """ def __init__(self, verbose: int = 0): @@ -68,7 +68,7 @@ def _on_rollout_start(self) -> None: @abstractmethod def _on_step(self) -> bool: """ - :return: (bool) If the callback returns False, training is aborted early. + :return: If the callback returns False, training is aborted early. """ return True @@ -79,7 +79,7 @@ def on_step(self) -> bool: For child callback (of an ``EventCallback``), this will be called when the event is triggered. - :return: (bool) If the callback returns False, training is aborted early. + :return: If the callback returns False, training is aborted early. """ self.n_calls += 1 # timesteps start at zero @@ -103,7 +103,7 @@ def update_locals(self, locals_: Dict[str, Any]) -> None: """ Update the references to the local variables. - :param locals_: (Dict[str, Any]) the local variables during rollout collection + :param locals_: the local variables during rollout collection """ self.locals.update(locals_) self.update_child_locals(locals_) @@ -112,7 +112,7 @@ def update_child_locals(self, locals_: Dict[str, Any]) -> None: """ Update the references to the local variables on sub callbacks. - :param locals_: (Dict[str, Any]) the local variables during rollout collection + :param locals_: the local variables during rollout collection """ pass @@ -121,9 +121,9 @@ class EventCallback(BaseCallback): """ Base class for triggering callback on event. - :param callback: (Optional[BaseCallback]) Callback that will be called + :param callback: Callback that will be called when an event is triggered. - :param verbose: (int) + :param verbose: """ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0): @@ -154,7 +154,7 @@ def update_child_locals(self, locals_: Dict[str, Any]) -> None: """ Update the references to the local variables. - :param locals_: (Dict[str, Any]) the local variables during rollout collection + :param locals_: the local variables during rollout collection """ if self.callback is not None: self.callback.update_locals(locals_) @@ -164,7 +164,7 @@ class CallbackList(BaseCallback): """ Class for chaining callbacks. - :param callbacks: (List[BaseCallback]) A list of callbacks that will be called + :param callbacks: A list of callbacks that will be called sequentially. """ @@ -204,7 +204,7 @@ def update_child_locals(self, locals_: Dict[str, Any]) -> None: """ Update the references to the local variables. - :param locals_: (Dict[str, Any]) the local variables during rollout collection + :param locals_: the local variables during rollout collection """ for callback in self.callbacks: callback.update_locals(locals_) @@ -214,9 +214,9 @@ class CheckpointCallback(BaseCallback): """ Callback for saving a model every ``save_freq`` steps - :param save_freq: (int) - :param save_path: (str) Path to the folder where the model will be saved. - :param name_prefix: (str) Common prefix to the saved models + :param save_freq: + :param save_path: Path to the folder where the model will be saved. + :param name_prefix: Common prefix to the saved models """ def __init__(self, save_freq: int, save_path: str, name_prefix="rl_model", verbose=0): @@ -243,8 +243,8 @@ class ConvertCallback(BaseCallback): """ Convert functional callback (old-style) to object. - :param callback: (callable) - :param verbose: (int) + :param callback: + :param verbose: """ def __init__(self, callback, verbose=0): @@ -261,20 +261,20 @@ class EvalCallback(EventCallback): """ Callback for evaluating an agent. - :param eval_env: (Union[gym.Env, VecEnv]) The environment used for initialization - :param callback_on_new_best: (Optional[BaseCallback]) Callback to trigger + :param eval_env: The environment used for initialization + :param callback_on_new_best: Callback to trigger when there is a new best model according to the ``mean_reward`` - :param n_eval_episodes: (int) The number of episodes to test the agent - :param eval_freq: (int) Evaluate the agent every eval_freq call of the callback. - :param log_path: (str) Path to a folder where the evaluations (``evaluations.npz``) + :param n_eval_episodes: The number of episodes to test the agent + :param eval_freq: Evaluate the agent every eval_freq call of the callback. + :param log_path: Path to a folder where the evaluations (``evaluations.npz``) will be saved. It will be updated at each evaluation. - :param best_model_save_path: (str) Path to a folder where the best model + :param best_model_save_path: Path to a folder where the best model according to performance on the eval env will be saved. - :param deterministic: (bool) Whether the evaluation should + :param deterministic: Whether the evaluation should use a stochastic or deterministic actions. - :param deterministic: (bool) Whether to render or not the environment during evaluation - :param render: (bool) Whether to render or not the environment during evaluation - :param verbose: (int) + :param deterministic: Whether to render or not the environment during evaluation + :param render: Whether to render or not the environment during evaluation + :param verbose: """ def __init__( @@ -378,7 +378,7 @@ def update_child_locals(self, locals_: Dict[str, Any]) -> None: """ Update the references to the local variables. - :param locals_: (Dict[str, Any]) the local variables during rollout collection + :param locals_: the local variables during rollout collection """ if self.callback: self.callback.update_locals(locals_) @@ -391,9 +391,9 @@ class StopTrainingOnRewardThreshold(BaseCallback): It must be used with the ``EvalCallback``. - :param reward_threshold: (float) Minimum expected reward per episode + :param reward_threshold: Minimum expected reward per episode to stop training. - :param verbose: (int) + :param verbose: """ def __init__(self, reward_threshold: float, verbose: int = 0): @@ -416,8 +416,8 @@ class EveryNTimesteps(EventCallback): """ Trigger a callback every ``n_steps`` timesteps - :param n_steps: (int) Number of timesteps between two trigger. - :param callback: (BaseCallback) Callback that will be called + :param n_steps: Number of timesteps between two trigger. + :param callback: Callback that will be called when the event is triggered. """ @@ -440,8 +440,8 @@ class StopTrainingOnMaxEpisodes(BaseCallback): For multiple environments presumes that, the desired behavior is that the agent trains on each env for ``max_episodes`` and in total for ``max_episodes * n_envs`` episodes. - :param max_episodes: (int) Maximum number of episodes to stop training. - :param verbose: (int) Select whether to print information about when training ended by reaching ``max_episodes`` + :param max_episodes: Maximum number of episodes to stop training. + :param verbose: Select whether to print information about when training ended by reaching ``max_episodes`` """ def __init__(self, max_episodes: int, verbose: int = 0): diff --git a/stable_baselines3/common/cmd_util.py b/stable_baselines3/common/cmd_util.py index af0067e82..4aede71c4 100644 --- a/stable_baselines3/common/cmd_util.py +++ b/stable_baselines3/common/cmd_util.py @@ -25,19 +25,19 @@ def make_vec_env( By default it uses a ``DummyVecEnv`` which is usually faster than a ``SubprocVecEnv``. - :param env_id: (str or Type[gym.Env]) the environment ID or the environment class - :param n_envs: (int) the number of environments you wish to have in parallel - :param seed: (int) the initial seed for the random number generator - :param start_index: (int) start rank index - :param monitor_dir: (str) Path to a folder where the monitor files will be saved. + :param env_id: the environment ID or the environment class + :param n_envs: the number of environments you wish to have in parallel + :param seed: the initial seed for the random number generator + :param start_index: start rank index + :param monitor_dir: Path to a folder where the monitor files will be saved. If None, no file will be written, however, the env will still be wrapped in a Monitor wrapper to provide additional information about training. - :param wrapper_class: (gym.Wrapper or callable) Additional wrapper to use on the environment. + :param wrapper_class: Additional wrapper to use on the environment. This can also be a function with single argument that wraps the environment in many things. - :param env_kwargs: (dict) Optional keyword argument to pass to the env constructor - :param vec_env_cls: (Type[VecEnv]) A custom ``VecEnv`` class constructor. Default: None. - :param vec_env_kwargs: (dict) Keyword arguments to pass to the ``VecEnv`` class constructor. - :return: (VecEnv) The wrapped environment + :param env_kwargs: Optional keyword argument to pass to the env constructor + :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. + :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. + :return: The wrapped environment """ env_kwargs = {} if env_kwargs is None else env_kwargs vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs @@ -90,18 +90,18 @@ def make_atari_env( Create a wrapped, monitored VecEnv for Atari. It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games. - :param env_id: (str or Type[gym.Env]) the environment ID or the environment class - :param n_envs: (int) the number of environments you wish to have in parallel - :param seed: (int) the initial seed for the random number generator - :param start_index: (int) start rank index - :param monitor_dir: (str) Path to a folder where the monitor files will be saved. + :param env_id: the environment ID or the environment class + :param n_envs: the number of environments you wish to have in parallel + :param seed: the initial seed for the random number generator + :param start_index: start rank index + :param monitor_dir: Path to a folder where the monitor files will be saved. If None, no file will be written, however, the env will still be wrapped in a Monitor wrapper to provide additional information about training. - :param wrapper_kwargs: (Dict[str, Any]) Optional keyword argument to pass to the ``AtariWrapper`` - :param env_kwargs: (Dict[str, Any]) Optional keyword argument to pass to the env constructor - :param vec_env_cls: (Type[VecEnv]) A custom ``VecEnv`` class constructor. Default: None. - :param vec_env_kwargs: (Dict[str, Any]) Keyword arguments to pass to the ``VecEnv`` class constructor. - :return: (VecEnv) The wrapped environment + :param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper`` + :param env_kwargs: Optional keyword argument to pass to the env constructor + :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. + :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. + :return: The wrapped environment """ if wrapper_kwargs is None: wrapper_kwargs = {} diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 44708806d..b8d78374e 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -29,7 +29,7 @@ def proba_distribution_net(self, *args, **kwargs): def proba_distribution(self, *args, **kwargs) -> "Distribution": """Set parameters of the distribution. - :return: (Distribution) self + :return: self """ @abstractmethod @@ -37,8 +37,8 @@ def log_prob(self, x: th.Tensor) -> th.Tensor: """ Returns the log likelihood - :param x: (th.Tensor) the taken action - :return: (th.Tensor) The log likelihood of the distribution + :param x: the taken action + :return: The log likelihood of the distribution """ @abstractmethod @@ -46,7 +46,7 @@ def entropy(self) -> Optional[th.Tensor]: """ Returns Shannon's entropy of the probability - :return: (Optional[th.Tensor]) the entropy, or None if no analytical form is known + :return: the entropy, or None if no analytical form is known """ @abstractmethod @@ -54,7 +54,7 @@ def sample(self) -> th.Tensor: """ Returns a sample from the probability distribution - :return: (th.Tensor) the stochastic action + :return: the stochastic action """ @abstractmethod @@ -63,15 +63,15 @@ def mode(self) -> th.Tensor: Returns the most likely action (deterministic output) from the probability distribution - :return: (th.Tensor) the stochastic action + :return: the stochastic action """ def get_actions(self, deterministic: bool = False) -> th.Tensor: """ Return actions according to the probability distribution. - :param deterministic: (bool) - :return: (th.Tensor) + :param deterministic: + :return: """ if deterministic: return self.mode() @@ -83,7 +83,7 @@ def actions_from_params(self, *args, **kwargs) -> th.Tensor: Returns samples from the probability distribution given its parameters. - :return: (th.Tensor) actions + :return: actions """ @abstractmethod @@ -92,7 +92,7 @@ def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]: Returns samples and the associated log probabilities from the probability distribution given its parameters. - :return: (th.Tuple[th.Tensor, th.Tensor]) actions and log prob + :return: actions and log prob """ @@ -101,8 +101,8 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: Continuous actions are usually considered to be independent, so we can sum components of the ``log_prob`` or the entropy. - :param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,) - :return: (th.Tensor) shape: (n_batch,) + :param tensor: shape: (n_batch, n_actions) or (n_batch,) + :return: shape: (n_batch,) """ if len(tensor.shape) > 1: tensor = tensor.sum(dim=1) @@ -115,7 +115,7 @@ class DiagGaussianDistribution(Distribution): """ Gaussian distribution with diagonal covariance matrix, for continuous actions. - :param action_dim: (int) Dimension of the action space. + :param action_dim: Dimension of the action space. """ def __init__(self, action_dim: int): @@ -131,9 +131,9 @@ def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> one output will be the mean of the Gaussian, the other parameter will be the standard deviation (log std in fact to allow negative values) - :param latent_dim: (int) Dimension of the last layer of the policy (before the action layer) - :param log_std_init: (float) Initial value for the log standard deviation - :return: (nn.Linear, nn.Parameter) + :param latent_dim: Dimension of the last layer of the policy (before the action layer) + :param log_std_init: Initial value for the log standard deviation + :return: """ mean_actions = nn.Linear(latent_dim, self.action_dim) # TODO: allow action dependent std @@ -144,9 +144,9 @@ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "Di """ Create the distribution given its parameters (mean, std) - :param mean_actions: (th.Tensor) - :param log_std: (th.Tensor) - :return: (DiagGaussianDistribution) + :param mean_actions: + :param log_std: + :return: """ action_std = th.ones_like(mean_actions) * log_std.exp() self.distribution = Normal(mean_actions, action_std) @@ -157,8 +157,8 @@ def log_prob(self, actions: th.Tensor) -> th.Tensor: Get the log probabilities of actions according to the distribution. Note that you must first call the ``proba_distribution()`` method. - :param actions: (th.Tensor) - :return: (th.Tensor) + :param actions: + :return: """ log_prob = self.distribution.log_prob(actions) return sum_independent_dims(log_prob) @@ -183,9 +183,9 @@ def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> T Compute the log probability of taking an action given the distribution parameters. - :param mean_actions: (th.Tensor) - :param log_std: (th.Tensor) - :return: (Tuple[th.Tensor, th.Tensor]) + :param mean_actions: + :param log_std: + :return: """ actions = self.actions_from_params(mean_actions, log_std) log_prob = self.log_prob(actions) @@ -196,8 +196,8 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): """ Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds. - :param action_dim: (int) Dimension of the action space. - :param epsilon: (float) small value to avoid NaN due to numerical imprecision. + :param action_dim: Dimension of the action space. + :param epsilon: small value to avoid NaN due to numerical imprecision. """ def __init__(self, action_dim: int, epsilon: float = 1e-6): @@ -250,7 +250,7 @@ class CategoricalDistribution(Distribution): """ Categorical distribution for discrete actions. - :param action_dim: (int) Number of discrete actions + :param action_dim: Number of discrete actions """ def __init__(self, action_dim: int): @@ -264,9 +264,9 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: it will be the logits of the Categorical distribution. You can then get probabilities using a softmax. - :param latent_dim: (int) Dimension of the last layer + :param latent_dim: Dimension of the last layer of the policy network (before the action layer) - :return: (nn.Linear) + :return: """ action_logits = nn.Linear(latent_dim, self.action_dim) return action_logits @@ -302,7 +302,7 @@ class MultiCategoricalDistribution(Distribution): """ MultiCategorical distribution for multi discrete actions. - :param action_dims: (List[int]) List of sizes of discrete action spaces + :param action_dims: List of sizes of discrete action spaces """ def __init__(self, action_dims: List[int]): @@ -316,9 +316,9 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: it will be the logits (flattened) of the MultiCategorical distribution. You can then get probabilities using a softmax on each sub-space. - :param latent_dim: (int) Dimension of the last layer + :param latent_dim: Dimension of the last layer of the policy network (before the action layer) - :return: (nn.Linear) + :return: """ action_logits = nn.Linear(latent_dim, sum(self.action_dims)) @@ -358,7 +358,7 @@ class BernoulliDistribution(Distribution): """ Bernoulli distribution for MultiBinary action spaces. - :param action_dim: (int) Number of binary actions + :param action_dim: Number of binary actions """ def __init__(self, action_dims: int): @@ -371,9 +371,9 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: Create the layer that represents the distribution: it will be the logits of the Bernoulli distribution. - :param latent_dim: (int) Dimension of the last layer + :param latent_dim: Dimension of the last layer of the policy network (before the action layer) - :return: (nn.Linear) + :return: """ action_logits = nn.Linear(latent_dim, self.action_dims) return action_logits @@ -413,18 +413,18 @@ class StateDependentNoiseDistribution(Distribution): It is used to create the noise exploration matrix and compute the log probability of an action with that noise. - :param action_dim: (int) Dimension of the action space. - :param full_std: (bool) Whether to use (n_features x n_actions) parameters + :param action_dim: Dimension of the action space. + :param full_std: Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param squash_output: (bool) Whether to squash the output using a tanh function, + :param squash_output: Whether to squash the output using a tanh function, this ensures bounds are satisfied. - :param learn_features: (bool) Whether to learn features for gSDE or not. + :param learn_features: Whether to learn features for gSDE or not. This will enable gradients to be backpropagated through the features ``latent_sde`` in the code. - :param epsilon: (float) small value to avoid NaN due to numerical imprecision. + :param epsilon: small value to avoid NaN due to numerical imprecision. """ def __init__( @@ -460,8 +460,8 @@ def get_std(self, log_std: th.Tensor) -> th.Tensor: Get the standard deviation from the learned parameter (log of it by default). This ensures that the std is positive. - :param log_std: (th.Tensor) - :return: (th.Tensor) + :param log_std: + :return: """ if self.use_expln: # From gSDE paper, it allows to keep variance @@ -485,8 +485,8 @@ def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None: Sample weights for the noise exploration matrix, using a centered Gaussian distribution. - :param log_std: (th.Tensor) - :param batch_size: (int) + :param log_std: + :param batch_size: """ std = self.get_std(log_std) self.weights_dist = Normal(th.zeros_like(std), std) @@ -503,11 +503,11 @@ def proba_distribution_net( one output will be the deterministic action, the other parameter will be the standard deviation of the distribution that control the weights of the noise matrix. - :param latent_dim: (int) Dimension of the last layer of the policy (before the action layer) - :param log_std_init: (float) Initial value for the log standard deviation - :param latent_sde_dim: (Optional[int]) Dimension of the last layer of the feature extractor + :param latent_dim: Dimension of the last layer of the policy (before the action layer) + :param log_std_init: Initial value for the log standard deviation + :param latent_sde_dim: Dimension of the last layer of the feature extractor for gSDE. By default, it is shared with the policy network. - :return: (nn.Linear, nn.Parameter) + :return: """ # Network for the deterministic action, it represents the mean of the distribution mean_actions_net = nn.Linear(latent_dim, self.action_dim) @@ -528,10 +528,10 @@ def proba_distribution( """ Create the distribution given its parameters (mean, std) - :param mean_actions: (th.Tensor) - :param log_std: (th.Tensor) - :param latent_sde: (th.Tensor) - :return: (StateDependentNoiseDistribution) + :param mean_actions: + :param log_std: + :param latent_sde: + :return: """ # Stop gradient if we don't want to influence the features self._latent_sde = latent_sde if self.learn_features else latent_sde.detach() @@ -607,7 +607,7 @@ class TanhBijector(object): using a squashing function (tanh) TODO: use Pyro instead (https://pyro.ai/) - :param epsilon: (float) small value to avoid NaN due to numerical imprecision. + :param epsilon: small value to avoid NaN due to numerical imprecision. """ def __init__(self, epsilon: float = 1e-6): @@ -633,8 +633,8 @@ def inverse(y: th.Tensor) -> th.Tensor: """ Inverse tanh. - :param y: (th.Tensor) - :return: (th.Tensor) + :param y: + :return: """ eps = th.finfo(y.dtype).eps # Clip the action to avoid NaN @@ -651,11 +651,11 @@ def make_proba_distribution( """ Return an instance of Distribution for the correct type of action space - :param action_space: (gym.spaces.Space) the input action space - :param use_sde: (bool) Force the use of StateDependentNoiseDistribution + :param action_space: the input action space + :param use_sde: Force the use of StateDependentNoiseDistribution instead of DiagGaussianDistribution - :param dist_kwargs: (Optional[Dict[str, Any]]) Keyword arguments to pass to the probability distribution - :return: (Distribution) the appropriate Distribution object + :param dist_kwargs: Keyword arguments to pass to the probability distribution + :return: the appropriate Distribution object """ if dist_kwargs is None: dist_kwargs = {} diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 3dc358074..4a4d075c0 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -147,9 +147,9 @@ def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> No Check the declared render modes and the `render()`/`close()` method of the environment. - :param env: (gym.Env) The environment to check - :param warn: (bool) Whether to output additional warnings - :param headless: (bool) Whether to disable render modes + :param env: The environment to check + :param warn: Whether to output additional warnings + :param headless: Whether to disable render modes that require a graphical interface. False by default. """ render_modes = env.metadata.get("render.modes") @@ -181,15 +181,15 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - It also optionally check that the environment is compatible with Stable-Baselines. - :param env: (gym.Env) The Gym environment that will be checked - :param warn: (bool) Whether to output additional warnings + :param env: The Gym environment that will be checked + :param warn: Whether to output additional warnings mainly related to the interaction with Stable Baselines - :param skip_render_check: (bool) Whether to skip the checks for the render method. + :param skip_render_check: Whether to skip the checks for the render method. True by default (useful for the CI) """ assert isinstance( env, gym.Env - ), "You environment must inherit from gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py" + ), "Your environment must inherit from the gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py" # ============= Check the spaces (observation and action) ================ _check_spaces(env) diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index 327300df9..8b89b0046 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -21,19 +21,19 @@ def evaluate_policy( Runs policy for ``n_eval_episodes`` episodes and returns average reward. This is made to work only with one env. - :param model: (BaseAlgorithm) The RL agent you want to evaluate. - :param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv`` + :param model: The RL agent you want to evaluate. + :param env: The gym environment. In the case of a ``VecEnv`` this must contain only one environment. - :param n_eval_episodes: (int) Number of episode to evaluate the agent - :param deterministic: (bool) Whether to use deterministic or stochastic actions - :param render: (bool) Whether to render the environment or not - :param callback: (callable) callback function to do additional checks, + :param n_eval_episodes: Number of episode to evaluate the agent + :param deterministic: Whether to use deterministic or stochastic actions + :param render: Whether to render the environment or not + :param callback: callback function to do additional checks, called after each step. - :param reward_threshold: (float) Minimum expected reward per episode, + :param reward_threshold: Minimum expected reward per episode, this will raise an error if the performance is not met - :param return_episode_rewards: (Optional[float]) If True, a list of reward per episode + :param return_episode_rewards: If True, a list of reward per episode will be returned instead of the mean. - :return: (float, float) Mean reward per episode, std of reward per episode + :return: Mean reward per episode, std of reward per episode returns ([float], [int]) when ``return_episode_rewards`` is True """ if isinstance(env, VecEnv): diff --git a/stable_baselines3/common/identity_env.py b/stable_baselines3/common/identity_env.py index 4a492e75a..0d6a74326 100644 --- a/stable_baselines3/common/identity_env.py +++ b/stable_baselines3/common/identity_env.py @@ -60,10 +60,10 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l """ Identity environment for testing purposes - :param low: (float) the lower bound of the box dim - :param high: (float) the upper bound of the box dim - :param eps: (float) the epsilon bound for correct value - :param ep_length: (int) the length of each episode in timesteps + :param low: the lower bound of the box dim + :param high: the upper bound of the box dim + :param eps: the epsilon bound for correct value + :param ep_length: the length of each episode in timesteps """ space = Box(low=low, high=high, shape=(1,), dtype=np.float32) super().__init__(ep_length=ep_length, space=space) @@ -85,8 +85,8 @@ def __init__(self, dim: int = 1, ep_length: int = 100): """ Identity environment for testing purposes - :param dim: (int) the size of the dimensions you want to learn - :param ep_length: (int) the length of each episode in timesteps + :param dim: the size of the dimensions you want to learn + :param ep_length: the length of each episode in timesteps """ space = MultiDiscrete([dim, dim]) super().__init__(ep_length=ep_length, space=space) @@ -97,8 +97,8 @@ def __init__(self, dim: int = 1, ep_length: int = 100): """ Identity environment for testing purposes - :param dim: (int) the size of the dimensions you want to learn - :param ep_length: (int) the length of each episode in timesteps + :param dim: the size of the dimensions you want to learn + :param ep_length: the length of each episode in timesteps """ space = MultiBinary(dim) super().__init__(ep_length=ep_length, space=space) @@ -108,11 +108,11 @@ class FakeImageEnv(Env): """ Fake image environment for testing purposes, it mimics Atari games. - :param action_dim: (int) Number of discrete actions - :param screen_height: (int) Height of the image - :param screen_width: (int) Width of the image - :param n_channels: (int) Number of color channels - :param discrete: (bool) + :param action_dim: Number of discrete actions + :param screen_height: Height of the image + :param screen_width: Width of the image + :param n_channels: Number of color channels + :param discrete: """ def __init__( diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index bed63273e..01e770949 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -32,9 +32,9 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, T """ Write a dictionary to file - :param key_values: (dict) - :param key_excluded: (dict) - :param step: (int) + :param key_values: + :param key_excluded: + :param step: """ raise NotImplementedError @@ -54,7 +54,7 @@ def write_sequence(self, sequence: List): """ write_sequence an array to file - :param sequence: (list) + :param sequence: """ raise NotImplementedError @@ -64,7 +64,7 @@ def __init__(self, filename_or_file: Union[str, TextIO]): """ log to a file, in a human readable format - :param filename_or_file: (str or File) the file to write the log to + :param filename_or_file: the file to write the log to """ if isinstance(filename_or_file, str): self.file = open(filename_or_file, "wt") @@ -145,7 +145,7 @@ def __init__(self, filename: str): """ log to a file, in the JSON format - :param filename: (str) the file to write the log to + :param filename: the file to write the log to """ self.file = open(filename, "wt") @@ -178,7 +178,7 @@ def __init__(self, filename: str): """ log to a file, in a CSV format - :param filename: (str) the file to write the log to + :param filename: the file to write the log to """ self.file = open(filename, "w+t") @@ -223,7 +223,7 @@ def __init__(self, folder: str): """ Dumps key/value pairs into TensorBoard's numeric format. - :param folder: (str) the folder to write the log to + :param folder: the folder to write the log to """ assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so" self.writer = SummaryWriter(log_dir=folder) @@ -257,10 +257,10 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWr """ return a logger for the requested format - :param _format: (str) the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard') - :param log_dir: (str) the logging directory - :param log_suffix: (str) the suffix for the log file - :return: (KVWriter) the logger + :param _format: the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard') + :param log_dir: the logging directory + :param log_suffix: the suffix for the log file + :return: the logger """ os.makedirs(log_dir, exist_ok=True) if _format == "stdout": @@ -288,9 +288,9 @@ def record(key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] Call this once for each diagnostic quantity, each iteration If called many times, last value will be used. - :param key: (Any) save to log this key - :param value: (Any) save to log this value - :param exclude: (str or tuple) outputs to be excluded + :param key: save to log this key + :param value: save to log this value + :param exclude: outputs to be excluded """ Logger.CURRENT.record(key, value, exclude) @@ -299,9 +299,9 @@ def record_mean(key: str, value: Union[int, float], exclude: Optional[Union[str, """ The same as record(), but if called many times, values averaged. - :param key: (Any) save to log this key - :param value: (Number) save to log this value - :param exclude: (str or tuple) outputs to be excluded + :param key: save to log this key + :param value: save to log this value + :param exclude: outputs to be excluded """ Logger.CURRENT.record_mean(key, value, exclude) @@ -310,7 +310,7 @@ def record_dict(key_values: Dict[str, Any]) -> None: """ Log a dictionary of key-value pairs. - :param key_values: (dict) the list of keys and values to save to log + :param key_values: the list of keys and values to save to log """ for key, value in key_values.items(): record(key, value) @@ -327,7 +327,7 @@ def get_log_dict() -> Dict: """ get the key values logs - :return: (dict) the logged values + :return: the logged values """ return Logger.CURRENT.name_to_value @@ -340,8 +340,8 @@ def log(*args, level: int = INFO) -> None: level: int. (see logger.py docs) If the global logger level is higher than the level argument here, don't print to stdout. - :param args: (list) log the arguments - :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) + :param args: log the arguments + :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) """ Logger.CURRENT.log(*args, level=level) @@ -352,7 +352,7 @@ def debug(*args) -> None: to the console and output files (if you've configured an output file). Using the DEBUG level. - :param args: (list) log the arguments + :param args: log the arguments """ log(*args, level=DEBUG) @@ -363,7 +363,7 @@ def info(*args) -> None: to the console and output files (if you've configured an output file). Using the INFO level. - :param args: (list) log the arguments + :param args: log the arguments """ log(*args, level=INFO) @@ -374,7 +374,7 @@ def warn(*args) -> None: to the console and output files (if you've configured an output file). Using the WARN level. - :param args: (list) log the arguments + :param args: log the arguments """ log(*args, level=WARN) @@ -385,7 +385,7 @@ def error(*args) -> None: to the console and output files (if you've configured an output file). Using the ERROR level. - :param args: (list) log the arguments + :param args: log the arguments """ log(*args, level=ERROR) @@ -394,7 +394,7 @@ def set_level(level: int) -> None: """ Set logging threshold on current logger. - :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) + :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) """ Logger.CURRENT.set_level(level) @@ -402,7 +402,7 @@ def set_level(level: int) -> None: def get_level() -> int: """ Get logging threshold on current logger. - :return: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) + :return: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) """ return Logger.CURRENT.level @@ -412,7 +412,7 @@ def get_dir() -> str: Get directory that log files are being written to. will be None if there is no output directory (i.e., if you didn't call start) - :return: (str) the logging directory + :return: the logging directory """ return Logger.CURRENT.get_dir() @@ -436,8 +436,8 @@ def __init__(self, folder: Optional[str], output_formats: List[KVWriter]): """ the logger class - :param folder: (str) the logging location - :param output_formats: ([str]) the list of output format + :param folder: the logging location + :param output_formats: the list of output format """ self.name_to_value = defaultdict(float) # values this iteration self.name_to_count = defaultdict(int) @@ -454,9 +454,9 @@ def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, . Call this once for each diagnostic quantity, each iteration If called many times, last value will be used. - :param key: (Any) save to log this key - :param value: (Any) save to log this value - :param exclude: (str or tuple) outputs to be excluded + :param key: save to log this key + :param value: save to log this value + :param exclude: outputs to be excluded """ self.name_to_value[key] = value self.name_to_excluded[key] = exclude @@ -465,9 +465,9 @@ def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[s """ The same as record(), but if called many times, values averaged. - :param key: (Any) save to log this key - :param value: (Number) save to log this value - :param exclude: (str or tuple) outputs to be excluded + :param key: save to log this key + :param value: save to log this value + :param exclude: outputs to be excluded """ if value is None: self.name_to_value[key] = None @@ -499,8 +499,8 @@ def log(self, *args, level: int = INFO) -> None: level: int. (see logger.py docs) If the global logger level is higher than the level argument here, don't print to stdout. - :param args: (list) log the arguments - :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) + :param args: log the arguments + :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) """ if self.level <= level: self._do_log(args) @@ -511,7 +511,7 @@ def set_level(self, level: int) -> None: """ Set logging threshold on current logger. - :param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) + :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) """ self.level = level @@ -520,7 +520,7 @@ def get_dir(self) -> str: Get directory that log files are being written to. will be None if there is no output directory (i.e., if you didn't call start) - :return: (str) the logging directory + :return: the logging directory """ return self.dir @@ -537,7 +537,7 @@ def _do_log(self, args) -> None: """ log to the requested format outputs - :param args: (list) the arguments to log + :param args: the arguments to log """ for _format in self.output_formats: if isinstance(_format, SeqWriter): @@ -552,9 +552,9 @@ def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] """ configure the current logger - :param folder: (Optional[str]) the save location + :param folder: the save location (if None, $SB3_LOGDIR, if still None, tempdir/baselines-[date & time]) - :param format_strings: (Optional[List[str]]) the output logging format + :param format_strings: the output logging format (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv']) """ if folder is None: @@ -594,8 +594,8 @@ def __init__(self, folder: Optional[str] = None, format_strings: Optional[List[s with ScopedConfigure(folder=None, format_strings=None): {code} - :param folder: (str) the logging folder - :param format_strings: ([str]) the list of output logging format + :param folder: the logging folder + :param format_strings: the list of output logging format """ self.dir = folder self.format_strings = format_strings @@ -619,8 +619,8 @@ def read_json(filename: str) -> pandas.DataFrame: """ read a json file using pandas - :param filename: (str) the file path to read - :return: (pandas.DataFrame) the data in the json + :param filename: the file path to read + :return: the data in the json """ data = [] with open(filename, "rt") as file_handler: @@ -633,7 +633,7 @@ def read_csv(filename: str) -> pandas.DataFrame: """ read a csv file using pandas - :param filename: (str) the file path to read - :return: (pandas.DataFrame) the data in the csv + :param filename: the file path to read + :return: the data in the csv """ return pandas.read_csv(filename, index_col=None, comment="#") diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index df6719f2a..17c0587bf 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -16,12 +16,12 @@ class Monitor(gym.Wrapper): """ A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. - :param env: (gym.Env) The environment - :param filename: (Optional[str]) the location to save a log file, can be None for no log - :param allow_early_resets: (bool) allows the reset of the environment before it is done - :param reset_keywords: (Tuple[str, ...]) extra keywords for the reset call, + :param env: The environment + :param filename: the location to save a log file, can be None for no log + :param allow_early_resets: allows the reset of the environment before it is done + :param reset_keywords: extra keywords for the reset call, if extra parameters are needed at reset - :param info_keywords: (Tuple[str, ...]) extra information to log, from the information return of env.step() + :param info_keywords: extra information to log, from the information return of env.step() """ EXT = "monitor.csv" @@ -67,7 +67,7 @@ def reset(self, **kwargs) -> np.ndarray: Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords - :return: (np.ndarray) the first observation of the environment + :return: the first observation of the environment """ if not self.allow_early_resets and not self.needs_reset: raise RuntimeError( @@ -87,8 +87,8 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, A """ Step the environment with the given action - :param action: (np.ndarray) the action - :return: (Tuple[np.ndarray, float, bool, Dict[Any, Any]]) observation, reward, done, information + :param action: the action + :return: observation, reward, done, information """ if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") @@ -124,7 +124,7 @@ def get_total_steps(self) -> int: """ Returns the total number of timesteps - :return: (int) + :return: """ return self.total_steps @@ -132,7 +132,7 @@ def get_episode_rewards(self) -> List[float]: """ Returns the rewards of all the episodes - :return: ([float]) + :return: """ return self.episode_rewards @@ -140,7 +140,7 @@ def get_episode_lengths(self) -> List[int]: """ Returns the number of timesteps of all the episodes - :return: ([int]) + :return: """ return self.episode_lengths @@ -148,7 +148,7 @@ def get_episode_times(self) -> List[float]: """ Returns the runtime in seconds of all the episodes - :return: ([float]) + :return: """ return self.episode_times @@ -165,8 +165,8 @@ def get_monitor_files(path: str) -> List[str]: """ get all the monitor files in the given path - :param path: (str) the logging folder - :return: ([str]) the log files + :param path: the logging folder + :return: the log files """ return glob(os.path.join(path, "*" + Monitor.EXT)) @@ -175,12 +175,12 @@ def load_results(path: str) -> pandas.DataFrame: """ Load all Monitor logs from a given directory path matching ``*monitor.csv`` - :param path: (str) the directory path containing the log file(s) - :return: (pandas.DataFrame) the logged data + :param path: the directory path containing the log file(s) + :return: the logged data """ monitor_files = get_monitor_files(path) if len(monitor_files) == 0: - raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, path)) + raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}") data_frames, headers = [], [] for file_name in monitor_files: with open(file_name, "rt") as file_handler: diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index f7b8b7372..a3b66db7f 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -28,8 +28,8 @@ class NormalActionNoise(ActionNoise): """ A Gaussian action noise - :param mean: (np.ndarray) the mean value of the noise - :param sigma: (np.ndarray) the scale of the noise (std here) + :param mean: the mean value of the noise + :param sigma: the scale of the noise (std here) """ def __init__(self, mean: np.ndarray, sigma: np.ndarray): @@ -50,11 +50,11 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise): Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab - :param mean: (np.ndarray) the mean of the noise - :param sigma: (np.ndarray) the scale of the noise - :param theta: (float) the rate of mean reversion - :param dt: (float) the timestep for the noise - :param initial_noise: (Optional[np.ndarray]) the initial value for the noise output, (if None: 0) + :param mean: the mean of the noise + :param sigma: the scale of the noise + :param theta: the rate of mean reversion + :param dt: the timestep for the noise + :param initial_noise: the initial value for the noise output, (if None: 0) """ def __init__( @@ -98,7 +98,7 @@ class VectorizedActionNoise(ActionNoise): A Vectorized action noise for parallel environments. :param base_noise: ActionNoise The noise generator to use - :param n_envs: (int) The number of parallel environments + :param n_envs: The number of parallel environments """ def __init__(self, base_noise: ActionNoise, n_envs: int): diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 31ec5efc8..cf08b4444 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -28,27 +28,27 @@ class OffPolicyAlgorithm(BaseAlgorithm): :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) :param policy_base: The base policy used by this method - :param learning_rate: (float or callable) learning rate for the optimizer, + :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) - :param buffer_size: (int) size of the replay buffer - :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts - :param batch_size: (int) Minibatch size for each gradient update - :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) - :param gamma: (float) the discount factor - :param train_freq: (int) Update the model every ``train_freq`` steps. Set to `-1` to disable. - :param gradient_steps: (int) How many gradient steps to do after each rollout + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq`` and ``n_episodes_rollout``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. - :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. + :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable. - :param action_noise: (ActionNoise) the action noise type (None by default), this can help + :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. - :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 :param policy_kwargs: Additional arguments to be passed to the policy on creation - :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) + :param tensorboard_log: the log location for tensorboard (if None, no logging) :param verbose: The verbosity level: 0 none, 1 training information, 2 debug :param device: Device on which the code should run. By default, it will try to use a Cuda compatible device and fallback to cpu @@ -64,9 +64,9 @@ class OffPolicyAlgorithm(BaseAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling + :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts) - :param sde_support: (bool) Whether the model support gSDE or not + :param sde_support: Whether the model support gSDE or not """ def __init__( @@ -166,7 +166,7 @@ def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) """ Save the replay buffer as a pickle file. - :param path: (Union[str,pathlib.Path, io.BufferedIOBase]) Path to the file where the replay buffer should be saved. + :param path: Path to the file where the replay buffer should be saved. if path is a str or pathlib.Path, the path is automatically created if necessary. """ assert self.replay_buffer is not None, "The replay buffer is not defined" @@ -176,7 +176,7 @@ def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) """ Load a replay buffer from a pickle file. - :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer. + :param path: Path to the pickled replay buffer. """ self.replay_buffer = load_from_pkl(path, self.verbose) assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class" @@ -281,11 +281,11 @@ def _sample_action( or sampling a random action (from a uniform distribution over the action space) or by adding noise to the deterministic output. - :param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration + :param action_noise: Action noise that will be used for exploration Required for deterministic policy (e.g. TD3). This can also be used in addition to the stochastic policy for SAC. - :param learning_starts: (int) Number of steps before learning for the warm-up phase. - :return: (Tuple[np.ndarray, np.ndarray]) action to take in the environment + :param learning_starts: Number of steps before learning for the warm-up phase. + :return: action to take in the environment and scaled action that will be stored in the replay buffer. The two differs when the action space is not normalized (bounds are not [-1, 1]). """ @@ -358,20 +358,20 @@ def collect_rollouts( """ Collect experiences and store them into a ReplayBuffer. - :param env: (VecEnv) The training environment - :param callback: (BaseCallback) Callback that will be called at each step + :param env: The training environment + :param callback: Callback that will be called at each step (and at the beginning and end of the rollout) - :param n_episodes: (int) Number of episodes to use to collect rollout data + :param n_episodes: Number of episodes to use to collect rollout data You can also specify a ``n_steps`` instead - :param n_steps: (int) Number of steps to use to collect rollout data + :param n_steps: Number of steps to use to collect rollout data You can also specify a ``n_episodes`` instead. - :param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration + :param action_noise: Action noise that will be used for exploration Required for deterministic policy (e.g. TD3). This can also be used in addition to the stochastic policy for SAC. - :param learning_starts: (int) Number of steps before learning for the warm-up phase. - :param replay_buffer: (ReplayBuffer) - :param log_interval: (int) Log data every ``log_interval`` episodes - :return: (RolloutReturn) + :param learning_starts: Number of steps before learning for the warm-up phase. + :param replay_buffer: + :param log_interval: Log data every ``log_interval`` episodes + :return: """ episode_rewards, total_timesteps = [], [] total_steps, total_episodes = 0, 0 diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index e8e5bf59f..85d08c86a 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -19,33 +19,33 @@ class OnPolicyAlgorithm(BaseAlgorithm): """ The base for On-Policy algorithms (ex: A2C/PPO). - :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) The learning rate, it can be a function + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function of the current progress remaining (from 1 to 0) - :param n_steps: (int) The number of steps to run for each environment per update + :param n_steps: The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) - :param gamma: (float) Discount factor - :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator. + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator. Equivalent to classic advantage when set to 1. - :param ent_coef: (float) Entropy coefficient for the loss calculation - :param vf_coef: (float) Value function coefficient for the loss calculation - :param max_grad_norm: (float) The maximum value for the gradient clipping - :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) - :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) :param monitor_wrapper: When creating an environment, whether to wrap it or not in a Monitor wrapper. - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( @@ -126,12 +126,12 @@ def collect_rollouts( """ Collect rollouts using the current policy and fill a `RolloutBuffer`. - :param env: (VecEnv) The training environment - :param callback: (BaseCallback) Callback that will be called at each step + :param env: The training environment + :param callback: Callback that will be called at each step (and at the beginning and end of the rollout) - :param rollout_buffer: (RolloutBuffer) Buffer to fill with rollouts - :param n_steps: (int) Number of experiences to collect per environment - :return: (bool) True if function returned with at least `n_rollout_steps` + :param rollout_buffer: Buffer to fill with rollouts + :param n_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` collected, False if callback terminated rollout prematurely. """ assert self._last_obs is not None, "No previous observation was provided" diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 8b5357c71..a3f503bff 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -32,18 +32,18 @@ class BaseModel(nn.Module, ABC): In the case of policies, the prediction is an action. In the case of critics, it is the estimated value of the observation. - :param observation_space: (gym.spaces.Space) The observation space of the environment - :param action_space: (gym.spaces.Space) The action space of the environment - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param observation_space: The observation space of the environment + :param action_space: The action space of the environment + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. - :param features_extractor: (nn.Module) Network to extract features + :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) - :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ @@ -86,8 +86,8 @@ def extract_features(self, obs: th.Tensor) -> th.Tensor: """ Preprocess the observation if needed and extract features. - :param obs: (th.Tensor) - :return: (th.Tensor) + :param obs: + :return: """ assert self.features_extractor is not None, "No feature extractor was set" preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) @@ -98,7 +98,7 @@ def _get_data(self) -> Dict[str, Any]: Get data that need to be saved in order to re-create the model. This corresponds to the arguments of the constructor. - :return: (Dict[str, Any]) + :return: """ return dict( observation_space=self.observation_space, @@ -114,7 +114,7 @@ def device(self) -> th.device: """Infer which device this policy lives on by inspecting its parameters. If it has no parameters, the 'auto' device is used as a fallback. - :return: (th.device)""" + :return:""" for param in self.parameters(): return param.device return get_device("auto") @@ -123,7 +123,7 @@ def save(self, path: str) -> None: """ Save model to a given location. - :param path: (str) + :param path: """ th.save({"state_dict": self.state_dict(), "data": self._get_data()}, path) @@ -132,9 +132,9 @@ def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel": """ Load model from path. - :param path: (str) - :param device: (Union[th.device, str]) Device on which the policy should be loaded. - :return: (BasePolicy) + :param path: + :param device: Device on which the policy should be loaded. + :return: """ device = get_device(device) saved_variables = th.load(path, map_location=device) @@ -149,7 +149,7 @@ def load_from_vector(self, vector: np.ndarray): """ Load parameters from a 1D vector. - :param vector: (np.ndarray) + :param vector: """ th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) @@ -157,7 +157,7 @@ def parameters_to_vector(self) -> np.ndarray: """ Convert the parameters to a 1D vector. - :return: (np.ndarray) + :return: """ return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() @@ -169,7 +169,7 @@ class BasePolicy(BaseModel): :param args: positional arguments passed through to `BaseModel`. :param kwargs: keyword arguments passed through to `BaseModel`. - :param squash_output: (bool) For continuous actions, whether the output is squashed + :param squash_output: For continuous actions, whether the output is squashed or not using a ``tanh()`` function. """ @@ -206,9 +206,9 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te By default provides a dummy implementation -- not all BasePolicy classes implement this, e.g. if they are a Critic in an Actor-Critic method. - :param observation: (th.Tensor) - :param deterministic: (bool) Whether to use stochastic or deterministic actions - :return: (th.Tensor) Taken action according to the policy + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy """ def predict( @@ -222,11 +222,11 @@ def predict( Get the policy action and state from an observation (and optional state). Includes sugar-coating to handle different observations (e.g. normalizing images). - :param observation: (np.ndarray) the input observation - :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies) - :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies) - :param deterministic: (bool) Whether or not to return deterministic actions. - :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next state (used in recurrent policies) """ # TODO (GH/1): add support for RNN policies @@ -281,8 +281,8 @@ def scale_action(self, action: np.ndarray) -> np.ndarray: Rescale the action from [low, high] to [-1, 1] (no need for symmetric action space) - :param action: (np.ndarray) Action to scale - :return: (np.ndarray) Scaled action + :param action: Action to scale + :return: Scaled action """ low, high = self.action_space.low, self.action_space.high return 2.0 * ((action - low) / (high - low)) - 1.0 @@ -303,32 +303,32 @@ class ActorCriticPolicy(BasePolicy): Policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the likes. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (Callable) Learning rate schedule (could be constant) - :param net_arch: ([int or dict]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param ortho_init: (bool) Whether to use or not orthogonal initialization - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param full_std: (bool) Whether to use (n_features x n_actions) parameters + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE - :param sde_net_arch: ([int]) Network architecture for extracting features + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param squash_output: (bool) Whether to squash the output using a tanh function, + :param squash_output: Whether to squash the output using a tanh function, this allows to ensure boundaries when using gSDE. - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. - :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ @@ -434,7 +434,7 @@ def reset_noise(self, n_envs: int = 1) -> None: """ Sample new weights for the exploration matrix. - :param n_envs: (int) + :param n_envs: """ assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE" self.action_dist.sample_weights(self.log_std, batch_size=n_envs) @@ -453,7 +453,7 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None: """ Create the networks and the optimizer. - :param lr_schedule: (Callable) Learning rate schedule + :param lr_schedule: Learning rate schedule lr_schedule(1) is the initial learning rate """ self._build_mlp_extractor() @@ -508,9 +508,9 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso """ Forward pass in all the networks (actor and critic) - :param obs: (th.Tensor) Observation - :param deterministic: (bool) Whether to sample or use deterministic actions - :return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) action, value and log probability of the action + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action """ latent_pi, latent_vf, latent_sde = self._get_latent(obs) # Evaluate the values for the given observations @@ -525,8 +525,8 @@ def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: Get the latent code (i.e., activations of the last layer of each network) for the different networks. - :param obs: (th.Tensor) Observation - :return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) Latent codes + :param obs: Observation + :return: Latent codes for the actor, the value function and for gSDE function """ # Preprocess the observation if needed @@ -543,9 +543,9 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optiona """ Retrieve action distribution given the latent codes. - :param latent_pi: (th.Tensor) Latent code for the actor - :param latent_sde: (Optional[th.Tensor]) Latent code for the gSDE exploration function - :return: (Distribution) Action distribution + :param latent_pi: Latent code for the actor + :param latent_sde: Latent code for the gSDE exploration function + :return: Action distribution """ mean_actions = self.action_net(latent_pi) @@ -569,9 +569,9 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te """ Get the action according to the policy for a given observation. - :param observation: (th.Tensor) - :param deterministic: (bool) Whether to use stochastic or deterministic actions - :return: (th.Tensor) Taken action according to the policy + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy """ latent_pi, _, latent_sde = self._get_latent(observation) distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) @@ -582,9 +582,9 @@ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tenso Evaluate actions according to the current policy, given the observations. - :param obs: (th.Tensor) - :param actions: (th.Tensor) - :return: (th.Tensor, th.Tensor, th.Tensor) estimated value, log likelihood of taking those actions + :param obs: + :param actions: + :return: estimated value, log likelihood of taking those actions and entropy of the action distribution. """ latent_pi, latent_vf, latent_sde = self._get_latent(obs) @@ -599,32 +599,32 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): CNN policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the likes. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (Callable) Learning rate schedule (could be constant) - :param net_arch: ([int or dict]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param ortho_init: (bool) Whether to use or not orthogonal initialization - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param full_std: (bool) Whether to use (n_features x n_actions) parameters + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE - :param sde_net_arch: ([int]) Network architecture for extracting features + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param squash_output: (bool) Whether to squash the output using a tanh function, + :param squash_output: Whether to squash the output using a tanh function, this allows to ensure boundaries when using gSDE. - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. - :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ @@ -682,16 +682,16 @@ class ContinuousCritic(BaseModel): By default, it creates two critic networks used to reduce overestimation thanks to clipped Q-learning (cf TD3 paper). - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features + :param observation_space: Obervation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, + :param features_dim: Number of features + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param n_critics: (int) Number of critic networks to create. + :param n_critics: Number of critic networks to create. """ def __init__( @@ -747,10 +747,10 @@ def create_sde_features_extractor( Create the neural network that will be used to extract features for the gSDE exploration function. - :param features_dim: (int) - :param sde_net_arch: ([int]) - :param activation_fn: (Type[nn.Module]) - :return: (nn.Sequential, int) + :param features_dim: + :param sde_net_arch: + :param activation_fn: + :return: """ # Special case: when using states as features (i.e. sde_net_arch is an empty list) # don't use any activation function @@ -769,9 +769,9 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[ Returns the registered policy from the base type and name. See `register_policy` for registering policies and explanation. - :param base_policy_type: (Type[BasePolicy]) the base policy class - :param name: (str) the policy name - :return: (Type[BasePolicy]) the policy + :param base_policy_type: the base policy class + :param name: the policy name + :return: the policy """ if base_policy_type not in _policy_registry: raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") @@ -804,8 +804,8 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None: In `get_policy_from_name`, the parent class (e.g. OnlinePolicy) is given and used to select and return the correct policy. - :param name: (str) the policy name - :param policy: (Type[BasePolicy]) the policy class + :param name: the policy name + :param policy: the policy class """ sub_class = None for cls in BasePolicy.__subclasses__(): diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 3efaf9b53..d96117a2e 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -15,11 +15,11 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True, Valid images: RGB, RGBD, GrayScale with values in [0, 255] - :param observation_space: (spaces.Space) - :param channels_last: (bool) - :param check_channels: (bool) Whether to do or not the check for the number of channels. + :param observation_space: + :param channels_last: + :param check_channels: Whether to do or not the check for the number of channels. e.g., with frame-stacking, the observation space may have more channels than expected. - :return: (bool) + :return: """ if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: # Check the type @@ -49,11 +49,11 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_im For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) For discrete observations, it create a one hot vector. - :param obs: (th.Tensor) Observation - :param observation_space: (spaces.Space) - :param normalize_images: (bool) Whether to normalize images or not + :param obs: Observation + :param observation_space: + :param normalize_images: Whether to normalize images or not (True by default) - :return: (th.Tensor) + :return: """ if isinstance(observation_space, spaces.Box): if is_image_space(observation_space) and normalize_images: @@ -85,8 +85,8 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: """ Get the shape of the observation (useful for the buffers). - :param observation_space: (spaces.Space) - :return: (Tuple[int, ...]) + :param observation_space: + :return: """ if isinstance(observation_space, spaces.Box): return observation_space.shape @@ -108,8 +108,8 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int: Get the dimension of the observation space when flattened. It does not apply to image observation space. - :param observation_space: (spaces.Space) - :return: (int) + :param observation_space: + :return: """ # See issue https://github.com/openai/gym/issues/1915 # it may be a problem for Dict/Tuple spaces too... @@ -124,8 +124,8 @@ def get_action_dim(action_space: spaces.Space) -> int: """ Get the dimension of the action space. - :param action_space: (spaces.Space) - :return: (int) + :param action_space: + :return: """ if isinstance(action_space, spaces.Box): return int(np.prod(action_space.shape)) diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py index 8fe805c1c..7e5b3cd82 100644 --- a/stable_baselines3/common/results_plotter.py +++ b/stable_baselines3/common/results_plotter.py @@ -20,9 +20,9 @@ def rolling_window(array: np.ndarray, window: int) -> np.ndarray: """ Apply a rolling window to a np.ndarray - :param array: (np.ndarray) the input Array - :param window: (int) length of the rolling window - :return: (np.ndarray) rolling window on the input array + :param array: the input Array + :param window: length of the rolling window + :return: rolling window on the input array """ shape = array.shape[:-1] + (array.shape[-1] - window + 1, window) strides = array.strides + (array.strides[-1],) @@ -33,11 +33,11 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl """ Apply a function to the rolling window of 2 arrays - :param var_1: (np.ndarray) variable 1 - :param var_2: (np.ndarray) variable 2 - :param window: (int) length of the rolling window - :param func: (numpy function) function to apply on the rolling window on variable 2 (such as np.mean) - :return: (Tuple[np.ndarray, np.ndarray]) the rolling output with applied function + :param var_1: variable 1 + :param var_2: variable 2 + :param window: length of the rolling window + :param func: function to apply on the rolling window on variable 2 (such as np.mean) + :return: the rolling output with applied function """ var_2_window = rolling_window(var_2, window) function_on_var2 = func(var_2_window, axis=-1) @@ -48,10 +48,10 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray """ Decompose a data frame variable to x ans ys - :param data_frame: (pd.DataFrame) the input data - :param x_axis: (str) the axis for the x and y output + :param data_frame: the input data + :param x_axis: the axis for the x and y output (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') - :return: (Tuple[np.ndarray, np.ndarray]) the x and y output + :return: the x and y output """ if x_axis == X_TIMESTEPS: x_var = np.cumsum(data_frame.l.values) @@ -74,11 +74,11 @@ def plot_curves( """ plot the curves - :param xy_list: (List[Tuple[np.ndarray, np.ndarray]]) the x and y coordinates to plot - :param x_axis: (str) the axis for the x and y output + :param xy_list: the x and y coordinates to plot + :param x_axis: the axis for the x and y output (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') - :param title: (str) the title of the plot - :param figsize: (Tuple[int, int]) Size of the figure (width, height) + :param title: the title of the plot + :param figsize: Size of the figure (width, height) """ plt.figure(title, figsize=figsize) @@ -104,12 +104,12 @@ def plot_results( """ Plot the results using csv files from ``Monitor`` wrapper. - :param dirs: ([str]) the save location of the results to plot - :param num_timesteps: (int or None) only plot the points below this value - :param x_axis: (str) the axis for the x and y output + :param dirs: the save location of the results to plot + :param num_timesteps: only plot the points below this value + :param x_axis: the axis for the x and y output (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs') - :param task_name: (str) the title of the task to plot - :param figsize: (Tuple[int, int]) Size of the figure (width, height) + :param task_name: the title of the task to plot + :param figsize: Size of the figure (width, height) """ data_frames = [] diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index b98fccc65..d10a775ba 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -9,8 +9,8 @@ def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): Calulates the running mean and std of a data stream https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm - :param epsilon: (float) helps with arithmetic issues - :param shape: (tuple) the shape of the data stream's output + :param epsilon: helps with arithmetic issues + :param shape: the shape of the data stream's output """ self.mean = np.zeros(shape, np.float64) self.var = np.ones(shape, np.float64) diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 3f862690d..f59cdd649 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -29,9 +29,9 @@ def recursive_getattr(obj: Any, attr: str, *args) -> Any: Ex: > MyObject.sub_object = SubObject(name='test') > recursive_getattr(MyObject, 'sub_object.name') # return test - :param obj: (Any) - :param attr: (str) Attribute to retrieve - :return: (Any) The attribute + :param obj: + :param attr: Attribute to retrieve + :return: The attribute """ def _getattr(obj: Any, attr: str) -> Any: @@ -48,9 +48,9 @@ def recursive_setattr(obj: Any, attr: str, val: Any) -> None: Ex: > MyObject.sub_object = SubObject(name='test') > recursive_setattr(MyObject, 'sub_object.name', 'hello') - :param obj: (Any) - :param attr: (str) Attribute to set - :param val: (Any) New value of the attribute + :param obj: + :param attr: Attribute to set + :param val: New value of the attribute """ pre, _, post = attr.rpartition(".") return setattr(recursive_getattr(obj, pre) if pre else obj, post, val) @@ -60,8 +60,8 @@ def is_json_serializable(item: Any) -> bool: """ Test if an object is serializable into JSON - :param item: (object) The object to be tested for JSON serialization. - :return: (bool) True if object is JSON serializable, false otherwise. + :param item: The object to be tested for JSON serialization. + :return: True if object is JSON serializable, false otherwise. """ # Try with try-except struct. json_serializable = True @@ -76,11 +76,11 @@ def data_to_json(data: Dict[str, Any]) -> str: """ Turn data (class parameters) into a JSON string for storing - :param data: (Dict[str, Any]) Dictionary of class parameters to be + :param data: Dictionary of class parameters to be stored. Items that are not JSON serializable will be pickled with Cloudpickle and stored as bytearray in the JSON file - :return: (str) JSON string of the data serialized. + :return: JSON string of the data serialized. """ # First, check what elements can not be JSONfied, # and turn them into byte-strings @@ -131,15 +131,15 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No """ Turn JSON serialization of class-parameters back into dictionary. - :param json_string: (str) JSON serialization of the class-parameters + :param json_string: JSON serialization of the class-parameters that should be loaded. - :param custom_objects: (dict) Dictionary of objects to replace + :param custom_objects: Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in `keras.models.load_model`. Useful when you have an object in file that can not be deserialized. - :return: (dict) Loaded class parameters. + :return: Loaded class parameters. """ if custom_objects is not None and not isinstance(custom_objects, dict): raise ValueError("custom_objects argument must be a dict or None") @@ -189,12 +189,12 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb points to a folder, it changes the path to path_2. If the path already exists and verbose == 2, it raises a warning. - :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open. + :param path: the path to open. if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the path actually exists. If path is a io.BufferedIOBase the path exists. - :param mode: (str) how to open the file. "w"|"write" for writing, "r"|"read" for reading. - :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. - :param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix. + :param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix. If mode is "r" then we attempt to open the path. If an error is raised and the suffix is not None, we attempt to open the path with the suffix. """ @@ -219,11 +219,11 @@ def open_path_str(path: str, mode: str, verbose=0, suffix=None) -> io.BufferedIO Open a path given by a string. If writing to the path, the function ensures that the path exists. - :param path: (str) the path to open. If mode is "w" then it ensures that the path exists + :param path: the path to open. If mode is "w" then it ensures that the path exists by creating the necessary folders and renaming path if it points to a folder. - :param mode: (str) how to open the file. "w" for writing, "r" for reading. - :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. - :param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix. + :param mode: how to open the file. "w" for writing, "r" for reading. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix. If mode is "r" then we attempt to open the path. If an error is raised and the suffix is not None, we attempt to open the path with the suffix. """ @@ -236,12 +236,12 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose=0, suffix=None) -> Open a path given by a string. If writing to the path, the function ensures that the path exists. - :param path: (pathlib.Path) the path to check. If mode is "w" then it + :param path: the path to check. If mode is "w" then it ensures that the path exists by creating the necessary folders and renaming path if it points to a folder. - :param mode: (str) how to open the file. "w" for writing, "r" for reading. - :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. - :param suffix: (str) The preferred suffix. If mode is "w" then the opened file has the suffix. + :param mode: how to open the file. "w" for writing, "r" for reading. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix. If mode is "r" then we attempt to open the path. If an error is raised and the suffix is not None, we attempt to open the path with the suffix. """ @@ -291,13 +291,13 @@ def save_to_zip_file( """ Save model data to a zip archive. - :param save_path: (Union[str, pathlib.Path, io.BufferedIOBase]) Where to store the model. + :param save_path: Where to store the model. if save_path is a str or pathlib.Path ensures that the path actually exists. :param data: Class parameters being stored (non-PyTorch variables) :param params: Model parameters being stored expected to contain an entry for every state_dict with its name and the state_dict. :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable. - :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information """ save_path = open_path(save_path, "w", verbose=0, suffix="zip") # data/params can be None, so do not @@ -327,11 +327,11 @@ def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose= If the path exists and is a directory, it will raise a warning and rename the path. If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'. - :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open. + :param path: the path to open. if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the path actually exists. If path is a io.BufferedIOBase the path exists. :param obj: The object to save. - :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. """ with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler: pickle.dump(obj, file_handler) @@ -342,10 +342,10 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0) Load an object from the path. If a suffix is provided in the path, it will use that suffix. If the path does not exist, it will attempt to load using the .pkl suffix. - :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) the path to open. + :param path: the path to open. if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the path actually exists. If path is a io.BufferedIOBase the path exists. - :param verbose: (int) Verbosity level, 0 means only warnings, 2 means debug information. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. """ with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler: return pickle.load(file_handler) @@ -360,11 +360,11 @@ def load_from_zip_file( """ Load model data from a .zip archive - :param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from + :param load_path: Where to load the model from :param load_data: Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) - :param device: (Union[th.device, str]) Device on which the code should run. - :return: (dict),(dict),(dict) Class parameters, model state_dicts (aka "params", dict of state_dict) + :param device: Device on which the code should run. + :return: Class parameters, model state_dicts (aka "params", dict of state_dict) and dict of pytorch variables """ load_path = open_path(load_path, "r", verbose=verbose, suffix="zip") diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 07449e9cb..43e7c5d18 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -13,8 +13,8 @@ class BaseFeaturesExtractor(nn.Module): """ Base class that represents a features extractor. - :param observation_space: (gym.Space) - :param features_dim: (int) Number of features extracted. + :param observation_space: + :param features_dim: Number of features extracted. """ def __init__(self, observation_space: gym.Space, features_dim: int = 0): @@ -36,7 +36,7 @@ class FlattenExtractor(BaseFeaturesExtractor): Feature extract that flatten the input. Used as a placeholder when feature extraction is not needed. - :param observation_space: (gym.Space) + :param observation_space: """ def __init__(self, observation_space: gym.Space): @@ -54,8 +54,8 @@ class NatureCNN(BaseFeaturesExtractor): "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533. - :param observation_space: (gym.Space) - :param features_dim: (int) Number of features extracted. + :param observation_space: + :param features_dim: Number of features extracted. This corresponds to the number of unit for the last layer. """ @@ -96,16 +96,16 @@ def create_mlp( Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. - :param input_dim: (int) Dimension of the input vector - :param output_dim: (int) - :param net_arch: (List[int]) Architecture of the neural net + :param input_dim: Dimension of the input vector + :param output_dim: + :param net_arch: Architecture of the neural net It represents the number of units per layer. The length of this list is the number of layers. - :param activation_fn: (Type[nn.Module]) The activation function + :param activation_fn: The activation function to use after each layer. - :param squash_output: (bool) Whether to squash the output using a Tanh + :param squash_output: Whether to squash the output using a Tanh activation function - :return: (List[nn.Module]) + :return: """ if len(net_arch) > 0: @@ -145,11 +145,11 @@ class MlpExtractor(nn.Module): Adapted from Stable Baselines. - :param feature_dim: (int) Dimension of the feature vector (can be the output of a CNN) - :param net_arch: ([int or dict]) The specification of the policy and value networks. + :param feature_dim: Dimension of the feature vector (can be the output of a CNN) + :param net_arch: The specification of the policy and value networks. See above for details on its formatting. - :param activation_fn: (Type[nn.Module]) The activation function to use for the networks. - :param device: (th.device) + :param activation_fn: The activation function to use for the networks. + :param device: """ def __init__( @@ -214,7 +214,7 @@ def __init__( def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: """ - :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network. + :return: latent_policy, latent_value of the specified network. If all layers are shared, then ``latent_policy == latent_value`` """ shared_latent = self.shared_net(features) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index aeab67b22..deb7bdb24 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -23,8 +23,8 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None: """ Seed the different random generators - :param seed: (int) - :param using_cuda: (bool) + :param seed: + :param using_cuda: """ # Seed python RNG random.seed(seed) @@ -50,9 +50,9 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: ev=1 => perfect prediction ev<0 => worse than just predicting zero - :param y_pred: (np.ndarray) the prediction - :param y_true: (np.ndarray) the expected value - :return: (float) explained variance of ypred and y + :param y_pred: the prediction + :param y_true: the expected value + :return: explained variance of ypred and y """ assert y_true.ndim == 1 and y_pred.ndim == 1 var_y = np.var(y_true) @@ -64,8 +64,8 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> Update the learning rate for a given optimizer. Useful when doing linear schedule. - :param optimizer: (th.optim.Optimizer) - :param learning_rate: (float) + :param optimizer: + :param learning_rate: """ for param_group in optimizer.param_groups: param_group["lr"] = learning_rate @@ -76,8 +76,8 @@ def get_schedule_fn(value_schedule: Union[Callable, float]) -> Callable: Transform (if needed) learning rate and clip range (for PPO) to callable. - :param value_schedule: (callable or float) - :return: (function) + :param value_schedule: + :return: """ # If the passed schedule is a float # create a constant function @@ -96,12 +96,12 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> Callable: This is used in DQN for linearly annealing the exploration fraction (epsilon for the epsilon-greedy strategy). - :params start: (float) value to start with if ``progress_remaining`` = 1 - :params end: (float) value to end with if ``progress_remaining`` = 0 - :params end_fraction: (float) fraction of ``progress_remaining`` + :params start: value to start with if ``progress_remaining`` = 1 + :params end: value to end with if ``progress_remaining`` = 0 + :params end_fraction: fraction of ``progress_remaining`` where end is reached e.g 0.1 then end is reached after 10% of the complete training process. - :return: (Callable) + :return: """ def func(progress_remaining: float) -> float: @@ -118,8 +118,8 @@ def constant_fn(val: float) -> Callable: Create a function that returns a constant It is useful for learning rate schedule (to avoid code duplication) - :param val: (float) - :return: (Callable) + :param val: + :return: """ def func(_): @@ -135,8 +135,8 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: For now, it supports only cpu and cuda. By default, it tries to use the gpu. - :param device: (Union[str, th.device]) One for 'auto', 'cuda', 'cpu' - :return: (th.device) + :param device: One for 'auto', 'cuda', 'cpu' + :return: """ # Cuda by default if device == "auto": @@ -156,7 +156,7 @@ def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int Returns the latest run number for the given log name and log path, by finding the greatest number in the directories. - :return: (int) latest run number + :return: latest run number """ max_run_id = 0 for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"): @@ -173,9 +173,9 @@ def configure_logger( """ Configure the logger's outputs. - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) - :param tb_log_name: (str) tensorboard log + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param tb_log_name: tensorboard log """ if tensorboard_log is not None and SummaryWriter is not None: latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name) @@ -199,9 +199,9 @@ def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, a - observation_space - action_space - :param env: (GymEnv) Environment to check for valid spaces - :param observation_space: (gym.spaces.Space) Observation space to check against - :param action_space: (gym.spaces.Space) Action space to check against + :param env: Environment to check for valid spaces + :param observation_space: Observation space to check against + :param action_space: Action space to check against """ if ( observation_space != env.observation_space @@ -221,9 +221,9 @@ def is_vectorized_observation(observation: np.ndarray, observation_space: gym.sp For every observation type, detects and validates the shape, then returns whether or not the observation is vectorized. - :param observation: (np.ndarray) the input observation to validate - :param observation_space: (gym.spaces) the observation space - :return: (bool) whether the given observation is vectorized or not + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not """ if isinstance(observation_space, gym.spaces.Box): if observation.shape == observation_space.shape: @@ -298,9 +298,9 @@ def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th. params (in place). See https://github.com/DLR-RM/stable-baselines3/issues/93 - :param params: (Iterable[th.nn.Parameter]) parameters to use to update the target params - :param target_params: (Iterable[th.nn.Parameter]) parameters to update - :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) + :param params: parameters to use to update the target params + :param target_params: parameters to update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) """ with th.no_grad(): for param, target_param in zip(params, target_params): diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index d1dfeb135..1940f20c0 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -21,9 +21,9 @@ def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[Vec """ Retrieve a ``VecEnvWrapper`` object by recursively searching. - :param env: (gym.Env) - :param vec_wrapper_class: (VecEnvWrapper) - :return: (VecEnvWrapper) + :param env: + :param vec_wrapper_class: + :return: """ env_tmp = env while isinstance(env_tmp, VecEnvWrapper): @@ -35,8 +35,8 @@ def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[Vec def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: """ - :param env: (gym.Env) - :return: (VecNormalize) + :param env: + :return: """ return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type @@ -46,8 +46,8 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: """ Sync eval env and train env when using VecNormalize - :param env: (GymEnv) - :param eval_env: (GymEnv) + :param env: + :param eval_env: """ env_tmp, eval_env_tmp = env, eval_env while isinstance(env_tmp, VecEnvWrapper): diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 36b4f7b14..9b542454e 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -25,9 +25,9 @@ def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cov (P,Q) are chosen to be as close as possible, and if N is square, then P=Q. - :param img_nhwc: (Sequence[np.ndarray]) list or array of images, ndim=4 once turned into array. img nhwc + :param img_nhwc: list or array of images, ndim=4 once turned into array. img nhwc n = batch index, h = height, w = width, c = channel - :return: (np.ndarray) img_HWc, ndim=3 + :return: img_HWc, ndim=3 """ img_nhwc = np.asarray(img_nhwc) n_images, height, width, n_channels = img_nhwc.shape @@ -49,9 +49,9 @@ class VecEnv(ABC): """ An abstract asynchronous, vectorized environment. - :param num_envs: (int) the number of environments - :param observation_space: (gym.spaces.Space) the observation space - :param action_space: (gym.spaces.Space) the action space + :param num_envs: the number of environments + :param observation_space: the observation space + :param action_space: the action space """ metadata = {"render.modes": ["human", "rgb_array"]} @@ -71,7 +71,7 @@ def reset(self) -> VecEnvObs: be cancelled and step_wait() should not be called until step_async() is invoked again. - :return: (VecEnvObs) observation + :return: observation """ raise NotImplementedError() @@ -108,9 +108,9 @@ def get_attr(self, attr_name: str, indices: "VecEnvIndices" = None) -> List[Any] """ Return attribute from vectorized environment. - :param attr_name: (str) The name of the attribute whose value to return - :param indices: (list,int) Indices of envs to get attribute from - :return: (list) List of values of 'attr_name' in all environments + :param attr_name: The name of the attribute whose value to return + :param indices: Indices of envs to get attribute from + :return: List of values of 'attr_name' in all environments """ raise NotImplementedError() @@ -119,10 +119,10 @@ def set_attr(self, attr_name: str, value: Any, indices: "VecEnvIndices" = None) """ Set attribute inside vectorized environments. - :param attr_name: (str) The name of attribute to assign new value - :param value: (obj) Value to assign to `attr_name` - :param indices: (list,int) Indices of envs to assign value - :return: (NoneType) + :param attr_name: The name of attribute to assign new value + :param value: Value to assign to `attr_name` + :param indices: Indices of envs to assign value + :return: """ raise NotImplementedError() @@ -131,11 +131,11 @@ def env_method(self, method_name: str, *method_args, indices: "VecEnvIndices" = """ Call instance methods of vectorized environments. - :param method_name: (str) The name of the environment method to invoke. - :param indices: (list,int) Indices of envs whose method to call - :param method_args: (tuple) Any positional arguments to provide in the call - :param method_kwargs: (dict) Any keyword arguments to provide in the call - :return: (list) List of items returned by the environment's method call + :param method_name: The name of the environment method to invoke. + :param indices: Indices of envs whose method to call + :param method_args: Any positional arguments to provide in the call + :param method_kwargs: Any keyword arguments to provide in the call + :return: List of items returned by the environment's method call """ raise NotImplementedError() @@ -143,8 +143,8 @@ def step(self, actions: np.ndarray) -> VecEnvStepReturn: """ Step the environments with the given action - :param actions: (np.ndarray) the action - :return: (VecEnvStepReturn) observation, reward, done, information + :param actions: the action + :return: observation, reward, done, information """ self.step_async(actions) return self.step_wait() @@ -185,8 +185,8 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: Sets the random seeds for all environments, based on a given seed. Each individual environment will still get its own seed, by incrementing the given seed. - :param seed: (Optional[int]) The random seed. May be None for completely random seeding. - :return: (List[Union[None, int]]) Returns a list containing the seeds for each individual env. + :param seed: The random seed. May be None for completely random seeding. + :return: Returns a list containing the seeds for each individual env. Note that all list elements may be None, if the env does not return anything when being seeded. """ pass @@ -201,9 +201,9 @@ def unwrapped(self) -> "VecEnv": def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]: """Check if an attribute reference is being hidden in a recursive call to __getattr__ - :param name: (str) name of attribute to check for - :param already_found: (bool) whether this attribute has already been found in a wrapper - :return: (Optional[str]) name of module whose attribute is being shadowed, if any. + :param name: name of attribute to check for + :param already_found: whether this attribute has already been found in a wrapper + :return: name of module whose attribute is being shadowed, if any. """ if hasattr(self, name) and already_found: return f"{type(self).__module__}.{type(self).__name__}" @@ -214,8 +214,8 @@ def _get_indices(self, indices: "VecEnvIndices") -> Iterable[int]: """ Convert a flexibly-typed reference to environment indices to an implied list of indices. - :param indices: (None,int,Iterable) refers to indices of envs. - :return: (list) the implied list of indices. + :param indices: refers to indices of envs. + :return: the implied list of indices. """ if indices is None: indices = range(self.num_envs) @@ -228,9 +228,9 @@ class VecEnvWrapper(VecEnv): """ Vectorized environment base class - :param venv: (VecEnv) the vectorized environment to wrap - :param observation_space: (Optional[gym.spaces.Space]) the observation space (can be None to load from venv) - :param action_space: (Optional[gym.spaces.Space]) the action space (can be None to load from venv) + :param venv: the vectorized environment to wrap + :param observation_space: the observation space (can be None to load from venv) + :param action_space: the action space (can be None to load from venv) """ def __init__( @@ -299,7 +299,7 @@ def __getattr__(self, name: str) -> Any: def _get_all_attributes(self) -> Dict[str, Any]: """Get all (inherited) instance and class attributes - :return: (Dict[str, Any]) all_attributes + :return: all_attributes """ all_attributes = self.__dict__.copy() all_attributes.update(self.class_attributes) @@ -308,8 +308,8 @@ def _get_all_attributes(self) -> Dict[str, Any]: def getattr_recursive(self, name: str): """Recursively check wrappers to find attribute. - :param name (str) name of attribute to look for - :return: (object) attribute + :param name: name of attribute to look for + :return: attribute """ all_attributes = self._get_all_attributes() if name in all_attributes: # attribute is present in this wrapper @@ -326,7 +326,7 @@ def getattr_recursive(self, name: str): def getattr_depth_check(self, name: str, already_found: bool): """See base class. - :return: (str or None) name of module whose attribute is being shadowed, if any. + :return: name of module whose attribute is being shadowed, if any. """ all_attributes = self._get_all_attributes() if name in all_attributes and already_found: @@ -346,7 +346,7 @@ class CloudpickleWrapper: """ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) - :param var: (Any) the variable you wish to wrap for pickling with cloudpickle + :param var: the variable you wish to wrap for pickling with cloudpickle """ def __init__(self, var: Any): diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 95bdc1a57..1b34d2d34 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -17,7 +17,7 @@ class DummyVecEnv(VecEnv): This can also be used for RL methods that require a vectorized environment, but that you want a single environments to train with. - :param env_fns: (List[Callable[[], gym.Env]]) a list of functions + :param env_fns: a list of functions that return environments to vectorize """ diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 937e94365..0f5d8bb3a 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -65,8 +65,8 @@ class SubprocVecEnv(VecEnv): ``if __name__ == "__main__":`` block. For more information, see the multiprocessing documentation. - :param env_fns: ([Gym Environment]) Environments to run in subprocesses - :param start_method: (str) method used to start the subprocesses. + :param env_fns: Environments to run in subprocesses + :param start_method: method used to start the subprocesses. Must be one of the methods returned by multiprocessing.get_all_start_methods(). Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. """ @@ -167,8 +167,8 @@ def _get_target_remotes(self, indices): Get the connection object needed to communicate with the wanted envs that are in subprocesses. - :param indices: (None,int,Iterable) refers to indices of envs. - :return: ([multiprocessing.Connection]) Connection object to communicate between processes. + :param indices: refers to indices of envs. + :return: Connection object to communicate between processes. """ indices = self._get_indices(indices) return [self.remotes[i] for i in indices] diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 0ebecb05b..c38c964b7 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -1,35 +1,36 @@ """ Helpers for dealing with vectorized environments. """ - from collections import OrderedDict +from typing import Any, Dict, List, Tuple, Union import gym import numpy as np -def copy_obs_dict(obs): +def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: """ Deep-copy a dict of numpy arrays. - :param obs: (OrderedDict): a dict of numpy arrays. - :return (OrderedDict) a dict of copied numpy arrays. + :param obs: a dict of numpy arrays. + :return: a dict of copied numpy arrays. """ assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) -def dict_to_obs(space, obs_dict): +def dict_to_obs( + space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray] +) -> Union[Dict[Any, np.ndarray], Tuple[np.ndarray, ...], np.ndarray]: """ Convert an internal representation raw_obs into the appropriate type specified by space. - :param space: (gym.spaces.Space) an observation space. - :param obs_dict: (OrderedDict) a dict of numpy arrays. - :return (ndarray, tuple or dict): returns an observation - of the same type as space. If space is Dict, function is identity; - if space is Tuple, converts dict to Tuple; otherwise, space is - unstructured and returns the value raw_obs[None]. + :param space: an observation space. + :param obs_dict: a dict of numpy arrays. + :return: returns an observation of the same type as space. + If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; + otherwise, space is unstructured and returns the value raw_obs[None]. """ if isinstance(space, gym.spaces.Dict): return obs_dict @@ -41,7 +42,7 @@ def dict_to_obs(space, obs_dict): return obs_dict[None] -def obs_space_info(obs_space): +def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: """ Get dict-structured information about a gym.Space. @@ -49,8 +50,8 @@ def obs_space_info(obs_space): Tuple spaces are converted into a dict with keys indexing into the tuple. Unstructured spaces are represented by {None: obs_space}. - :param obs_space: (gym.spaces.Space) an observation space - :return (tuple) A tuple (keys, shapes, dtypes): + :param obs_space: an observation space + :return: A tuple (keys, shapes, dtypes): keys: a list of dict keys. shapes: a dict mapping keys to shapes. dtypes: a dict mapping keys to dtypes. diff --git a/stable_baselines3/common/vec_env/vec_check_nan.py b/stable_baselines3/common/vec_env/vec_check_nan.py index 613be0da7..c943cb266 100644 --- a/stable_baselines3/common/vec_env/vec_check_nan.py +++ b/stable_baselines3/common/vec_env/vec_check_nan.py @@ -10,10 +10,10 @@ class VecCheckNan(VecEnvWrapper): NaN and inf checking wrapper for vectorized environment, will raise a warning by default, allowing you to know from what the NaN of inf originated from. - :param venv: (VecEnv) the vectorized environment to wrap - :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning - :param warn_once: (bool) Whether or not to only warn once. - :param check_inf: (bool) Whether or not to check for +inf or -inf as well + :param venv: the vectorized environment to wrap + :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning + :param warn_once: Whether or not to only warn once. + :param check_inf: Whether or not to check for +inf or -inf as well """ def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True): diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 213dd6a83..0c407485b 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -11,14 +11,14 @@ class VecNormalize(VecEnvWrapper): A moving average, normalizing wrapper for vectorized environment. has support for saving/loading moving average, - :param venv: (VecEnv) the vectorized environment to wrap - :param training: (bool) Whether to update or not the moving average - :param norm_obs: (bool) Whether to normalize observation or not (default: True) - :param norm_reward: (bool) Whether to normalize rewards or not (default: True) - :param clip_obs: (float) Max absolute value for observation - :param clip_reward: (float) Max value absolute for discounted reward - :param gamma: (float) discount factor - :param epsilon: (float) To avoid division by zero + :param venv: the vectorized environment to wrap + :param training: Whether to update or not the moving average + :param norm_obs: Whether to normalize observation or not (default: True) + :param norm_reward: Whether to normalize rewards or not (default: True) + :param clip_obs: Max absolute value for observation + :param clip_reward: Max value absolute for discounted reward + :param gamma: discount factor + :param epsilon: To avoid division by zero """ def __init__( @@ -58,7 +58,7 @@ def __setstate__(self, state): User must call set_venv() after unpickling before using. - :param state: (dict)""" + :param state:""" self.__dict__.update(state) assert "venv" not in state self.venv = None @@ -69,7 +69,7 @@ def set_venv(self, venv): Also sets attributes derived from this such as `num_env`. - :param venv: (VecEnv) + :param venv: """ if self.venv is not None: raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.") @@ -162,9 +162,9 @@ def load(load_path: str, venv: VecEnv) -> "VecNormalize": """ Loads a saved VecNormalize object. - :param load_path: (str) the path to load from. - :param venv: (VecEnv) the VecEnv to wrap. - :return: (VecNormalize) + :param load_path: the path to load from. + :param venv: the VecEnv to wrap. + :return: """ with open(load_path, "rb") as file_handler: vec_normalize = pickle.load(file_handler) @@ -176,7 +176,7 @@ def save(self, save_path: str) -> None: Save current VecNormalize object with all running statistics and settings (e.g. clip_obs) - :param save_path: (str) The path to save to + :param save_path: The path to save to """ with open(save_path, "wb") as file_handler: pickle.dump(self, file_handler) diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index 64b92ebb0..343a994c8 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -10,7 +10,7 @@ class VecTransposeImage(VecEnvWrapper): Re-order channels, from HxWxC to CxHxW. It is required for PyTorch convolution layers. - :param venv: (VecEnv) + :param venv: """ def __init__(self, venv: VecEnv): @@ -24,8 +24,8 @@ def transpose_space(observation_space: spaces.Box) -> spaces.Box: """ Transpose an observation space (re-order channels). - :param observation_space: (spaces.Box) - :return: (spaces.Box) + :param observation_space: + :return: """ assert is_image_space(observation_space), "The observation space must be an image" width, height, channels = observation_space.shape @@ -37,8 +37,8 @@ def transpose_image(image: np.ndarray) -> np.ndarray: """ Transpose an image or batch of images (re-order channels). - :param image: (np.ndarray) - :return: (np.ndarray) + :param image: + :return: """ if len(image.shape) == 3: return np.transpose(image, (2, 0, 1)) diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index f7b6ffac7..fc3be1b0b 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -15,13 +15,13 @@ class VecVideoRecorder(VecEnvWrapper): Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. It requires ffmpeg or avconv to be installed on the machine. - :param venv: (VecEnv or VecEnvWrapper) - :param video_folder: (str) Where to save videos - :param record_video_trigger: (func) Function that defines when to start recording. + :param venv: + :param video_folder: Where to save videos + :param record_video_trigger: Function that defines when to start recording. The function takes the current number of step, and returns whether we should start recording or not. - :param video_length: (int) Length of recorded videos - :param name_prefix: (str) Prefix to the video name + :param video_length: Length of recorded videos + :param name_prefix: Prefix to the video name """ def __init__(self, venv, video_folder, record_video_trigger, video_length=200, name_prefix="rl-video"): diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index cc6e2899f..8acd0f824 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -19,36 +19,36 @@ class DDPG(TD3): Note: we treat DDPG as a special case of its successor TD3. - :param policy: (DDPGPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) learning rate for adam optimizer, + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: learning rate for adam optimizer, the same learning rate will be used for all networks (Q-Values, Actor and Value function) it can be a function of the current progress remaining (from 1 to 0) - :param buffer_size: (int) size of the replay buffer - :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts - :param batch_size: (int) Minibatch size for each gradient update - :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) - :param gamma: (float) the discount factor - :param train_freq: (int) Update the model every ``train_freq`` steps. Set to `-1` to disable. - :param gradient_steps: (int) How many gradient steps to do after each rollout + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq`` and ``n_episodes_rollout``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. - :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. + :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable. - :param action_noise: (ActionNoise) the action noise type (None by default), this can help + :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. - :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 - :param create_eval_env: (bool) Whether to create a second environment that will be + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 8cfcca86c..29323c087 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -19,40 +19,40 @@ class DQN(OffPolicyAlgorithm): Default hyperparameters are taken from the nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. - :param policy: (DQNPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) The learning rate, it can be a function + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function of the current progress (from 1 to 0) - :param buffer_size: (int) size of the replay buffer - :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts - :param batch_size: (int) Minibatch size for each gradient update - :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update - :param gamma: (float) the discount factor - :param train_freq: (int) Update the model every ``train_freq`` steps. Set to `-1` to disable. - :param gradient_steps: (int) How many gradient steps to do after each rollout + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq`` and ``n_episodes_rollout``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. - :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. + :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable. - :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 - :param target_update_interval: (int) update the target network every ``target_update_interval`` + :param target_update_interval: update the target network every ``target_update_interval`` environment steps. - :param exploration_fraction: (float) fraction of entire training period over which the exploration rate is reduced - :param exploration_initial_eps: (float) initial value of random action probability - :param exploration_final_eps: (float) final value of random action probability - :param max_grad_norm: (float) The maximum value for the gradient clipping - :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced + :param exploration_initial_eps: initial value of random action probability + :param exploration_final_eps: final value of random action probability + :param max_grad_norm: The maximum value for the gradient clipping + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( @@ -192,11 +192,11 @@ def predict( """ Overrides the base_class predict function to include epsilon-greedy exploration. - :param observation: (np.ndarray) the input observation - :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies) - :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies) - :param deterministic: (bool) Whether or not to return deterministic actions. - :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next state (used in recurrent policies) """ if not deterministic and np.random.rand() < self.exploration_rate: diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index f8ec737de..890d3ed89 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -12,11 +12,11 @@ class QNetwork(BasePolicy): """ Action-Value (Q-Value) network for DQN - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, + :param observation_space: Observation space + :param action_space: Action space + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) """ @@ -53,8 +53,8 @@ def forward(self, obs: th.Tensor) -> th.Tensor: """ Predict the q-values. - :param obs: (th.Tensor) Observation - :return: (th.Tensor) The estimated Q-Value for each action. + :param obs: Observation + :return: The estimated Q-Value for each action. """ return self.q_net(self.extract_features(obs)) @@ -83,19 +83,19 @@ class DQNPolicy(BasePolicy): """ Policy class with Q-Value Net and target net for DQN - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. - :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ @@ -146,7 +146,7 @@ def _build(self, lr_schedule: Callable) -> None: """ Create the network and the optimizer. - :param lr_schedule: (Callable) Learning rate schedule + :param lr_schedule: Learning rate schedule lr_schedule(1) is the initial learning rate """ @@ -193,17 +193,17 @@ class CnnPolicy(DQNPolicy): """ Policy class for DQN when using images as input. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param normalize_images: (bool) Whether to normalize images or not, + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 2958f708b..47420cf2d 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -23,43 +23,43 @@ class PPO(OnPolicyAlgorithm): Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html - :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) The learning rate, it can be a function + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function of the current progress remaining (from 1 to 0) - :param n_steps: (int) The number of steps to run for each environment per update + :param n_steps: The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) - :param batch_size: (int) Minibatch size - :param n_epochs: (int) Number of epoch when optimizing the surrogate loss - :param gamma: (float) Discount factor - :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator - :param clip_range: (float or callable) Clipping parameter, it can be a function of the current progress + :param batch_size: Minibatch size + :param n_epochs: Number of epoch when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param clip_range: Clipping parameter, it can be a function of the current progress remaining (from 1 to 0). - :param clip_range_vf: (float or callable) Clipping parameter for the value function, + :param clip_range_vf: Clipping parameter for the value function, it can be a function of the current progress remaining (from 1 to 0). This is a parameter specific to the OpenAI implementation. If None is passed (default), no clipping will be done on the value function. IMPORTANT: this clipping depends on the reward scaling. - :param ent_coef: (float) Entropy coefficient for the loss calculation - :param vf_coef: (float) Value function coefficient for the loss calculation - :param max_grad_norm: (float) The maximum value for the gradient clipping - :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) - :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param target_kl: (float) Limit the KL divergence between updates, + :param target_kl: Limit the KL divergence between updates, because the clipping is not enough to prevent large update see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) By default, there is no limit on the kl div. - :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 6dd98e18a..c11c43707 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -18,25 +18,25 @@ class Actor(BasePolicy): """ Actor network (policy) for SAC. - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features + :param observation_space: Obervation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param full_std: (bool) Whether to use (n_features x n_actions) parameters + :param features_dim: Number of features + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) when using gSDE. - :param sde_net_arch: ([int]) Network architecture for extracting features + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. - :param normalize_images: (bool) Whether to normalize images or not, + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) """ @@ -132,7 +132,7 @@ def get_std(self) -> th.Tensor: but is slightly different when using ``expln`` function (cf StateDependentNoiseDistribution doc). - :return: (th.Tensor) + :return: """ msg = "get_std() is only available when using gSDE" assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg @@ -142,7 +142,7 @@ def reset_noise(self, batch_size: int = 1) -> None: """ Sample new weights for the exploration matrix, when using gSDE. - :param batch_size: (int) + :param batch_size: """ msg = "reset_noise() is only available when using gSDE" assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg @@ -152,8 +152,8 @@ def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, """ Get the parameters for the action distribution. - :param obs: (th.Tensor) - :return: (Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]) + :param obs: + :return: Mean, standard deviation and optional keyword arguments. """ features = self.extract_features(obs) @@ -189,30 +189,30 @@ class SACPolicy(BasePolicy): """ Policy class (with both actor and critic) for SAC. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param sde_net_arch: ([int]) Network architecture for extracting features + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. - :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer - :param n_critics: (int) Number of critic networks to create. + :param n_critics: Number of critic networks to create. """ def __init__( @@ -321,7 +321,7 @@ def reset_noise(self, batch_size: int = 1) -> None: """ Sample new weights for the exploration matrix, when using gSDE. - :param batch_size: (int) + :param batch_size: """ self.actor.reset_noise(batch_size=batch_size) @@ -345,28 +345,28 @@ class CnnPolicy(SACPolicy): """ Policy class (with both actor and critic) for SAC. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param use_sde: (bool) Whether to use State Dependent Exploration or not - :param log_std_init: (float) Initial value for the log standard deviation - :param sde_net_arch: ([int]) Network architecture for extracting features + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features. - :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability. - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param normalize_images: (bool) Whether to normalize images or not, + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer - :param n_critics: (int) Number of critic networks to create. + :param n_critics: Number of critic networks to create. """ def __init__( diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index bbc7eb14a..785cc9ddb 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -26,48 +26,48 @@ class SAC(OffPolicyAlgorithm): Note: we use double q target and not value target as discussed in https://github.com/hill-a/stable-baselines/issues/270 - :param policy: (SACPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) learning rate for adam optimizer, + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: learning rate for adam optimizer, the same learning rate will be used for all networks (Q-Values, Actor and Value function) it can be a function of the current progress remaining (from 1 to 0) - :param buffer_size: (int) size of the replay buffer - :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts - :param batch_size: (int) Minibatch size for each gradient update - :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) - :param gamma: (float) the discount factor - :param train_freq: (int) Update the model every ``train_freq`` steps. Set to `-1` to disable. - :param gradient_steps: (int) How many gradient steps to do after each rollout + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq`` and ``n_episodes_rollout``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. - :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. + :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable. - :param action_noise: (ActionNoise) the action noise type (None by default), this can help + :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. - :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 - :param ent_coef: (str or float) Entropy regularization coefficient. (Equivalent to + :param ent_coef: Entropy regularization coefficient. (Equivalent to inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off. Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value) - :param target_update_interval: (int) update the target network every ``target_network_update_freq`` + :param target_update_interval: update the target network every ``target_network_update_freq`` gradient steps. - :param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) - :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE) + :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``) + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) - :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling + :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts) - :param create_eval_env: (bool) Whether to create a second environment that will be + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__( diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index c5cdd8e41..401a54108 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -13,14 +13,14 @@ class Actor(BasePolicy): """ Actor network (policy) for TD3. - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features + :param observation_space: Obervation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, + :param features_dim: Number of features + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) """ @@ -79,21 +79,21 @@ class TD3Policy(BasePolicy): """ Policy class (with both actor and critic) for TD3. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (Callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. - :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer - :param n_critics: (int) Number of critic networks to create. + :param n_critics: Number of critic networks to create. """ def __init__( @@ -195,21 +195,21 @@ class CnnPolicy(TD3Policy): """ Policy class (with both actor and critic) for TD3. - :param observation_space: (gym.spaces.Space) Observation space - :param action_space: (gym.spaces.Space) Action space - :param lr_schedule: (Callable) Learning rate schedule (could be constant) - :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. - :param activation_fn: (Type[nn.Module]) Activation function - :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. - :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the feature extractor. - :param normalize_images: (bool) Whether to normalize images or not, + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default - :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer - :param n_critics: (int) Number of critic networks to create. + :param n_critics: Number of critic networks to create. """ def __init__( diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 6784562c9..a1186ecc8 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -21,41 +21,41 @@ class TD3(OffPolicyAlgorithm): Paper: https://arxiv.org/abs/1802.09477 Introduction to TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html - :param policy: (TD3Policy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) - :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str) - :param learning_rate: (float or callable) learning rate for adam optimizer, + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: learning rate for adam optimizer, the same learning rate will be used for all networks (Q-Values, Actor and Value function) it can be a function of the current progress remaining (from 1 to 0) - :param buffer_size: (int) size of the replay buffer - :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts - :param batch_size: (int) Minibatch size for each gradient update - :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1) - :param gamma: (float) the discount factor - :param train_freq: (int) Update the model every ``train_freq`` steps. Set to `-1` to disable. - :param gradient_steps: (int) How many gradient steps to do after each rollout + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq`` and ``n_episodes_rollout``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. - :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes. + :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable. - :param action_noise: (ActionNoise) the action noise type (None by default), this can help + :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. - :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 - :param policy_delay: (int) Policy and target networks will only be updated once every policy_delay steps + :param policy_delay: Policy and target networks will only be updated once every policy_delay steps per training steps. The Q values will be updated policy_delay more often (update every training step). - :param target_policy_noise: (float) Standard deviation of Gaussian noise added to target policy + :param target_policy_noise: Standard deviation of Gaussian noise added to target policy (smoothing noise) - :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise. - :param create_eval_env: (bool) Whether to create a second environment that will be + :param target_noise_clip: Limit for absolute value of target policy smoothing noise. + :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug - :param seed: (int) Seed for the pseudo random generators - :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param _init_setup_model: Whether or not to build the network at the creation of the instance """ def __init__(