From 1288bc9f5ab9991bdcdd0c67674d32ea86bb4847 Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Tue, 2 Apr 2024 11:58:10 +0200 Subject: [PATCH 1/9] Capitalise constants and remove redundant default description in buffer. --- tianshou/data/buffer/base.py | 22 +++++++++++----------- tianshou/data/buffer/manager.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 53f9bd8eb..0519d4353 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -21,16 +21,16 @@ class ReplayBuffer: :doc:`/01_tutorials/01_concepts`. :param size: the maximum size of replay buffer. - :param stack_num: the frame-stack sampling argument, should be greater than or - equal to 1. Default to 1 (no stacking). - :param ignore_obs_next: whether to not store obs_next. Default to False. + :param stack_num: the frame-stack sampling argument. It should be greater than or + equal to 1 (no stacking). + :param ignore_obs_next: whether to not store obs_next. :param save_only_last_obs: only save the last obs/obs_next when it has a shape - of (timestep, ...) because of temporal stacking. Default to False. + of (timestep, ...) because of temporal stacking. :param sample_avail: the parameter indicating sampling only available index - when using frame-stack sampling method. Default to False. + when using frame-stack sampling method. """ - _reserved_keys = ( + _RESERVED_KEYS = ( "obs", "act", "rew", @@ -41,7 +41,7 @@ class ReplayBuffer: "info", "policy", ) - _input_keys = ( + _INPUT_KEYS = ( "obs", "act", "rew", @@ -104,7 +104,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" - assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned" + assert key not in self._RESERVED_KEYS, f"key '{key}' is reserved and cannot be assigned" super().__setattr__(key, value) def save_hdf5(self, path: str, compression: str | None = None) -> None: @@ -162,7 +162,7 @@ def reset(self, keep_statistics: bool = False) -> None: def set_batch(self, batch: RolloutBatchProtocol) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" assert len(batch) == self.maxsize and set(batch.keys()).issubset( - self._reserved_keys, + self._RESERVED_KEYS, ), "Input batch doesn't meet ReplayBuffer's data form requirement." self._meta = batch @@ -352,7 +352,7 @@ def get( :param str key: the key to get, should be one of the reserved_keys. :param default_value: if the given key's data is not found and default_value is set, return this default_value. - :param stack_num: Default to self.stack_num. + :param stack_num: Set to self.stack_num if set to None. 
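+            An illustrative call (a sketch; ``buf`` and ``index`` are placeholder names,
+            and the buffer is assumed to already hold enough transitions for stacking)::
+
+                stacked_obs = buf.get(index, "obs", stack_num=4)  # obs stacked over the last 4 frames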
""" if key not in self._meta and default_value is not None: return default_value @@ -415,6 +415,6 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBat "policy": self.get(indices, "policy", Batch()), } for key in self._meta.__dict__: - if key not in self._input_keys: + if key not in self._INPUT_KEYS: batch_dict[key] = self._meta[key][indices] return cast(RolloutBatchProtocol, Batch(batch_dict)) diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index a495b0ada..fbff4a4af 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -127,7 +127,7 @@ def add( """ # preprocess batch new_batch = Batch() - for key in set(self._reserved_keys).intersection(batch.keys()): + for key in set(self._RESERVED_KEYS).intersection(batch.keys()): new_batch.__dict__[key] = batch[key] batch = new_batch batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) From b5a0a98af5725637af8b5e2546e16213bf05f01b Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Tue, 2 Apr 2024 13:57:38 +0200 Subject: [PATCH 2/9] Refactor input validation of collector --- test/base/test_collector.py | 2 +- tianshou/data/collector.py | 91 +++++++++++++++++++++++++++---------- 2 files changed, 67 insertions(+), 26 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6bc1703f6..d9d818ae9 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -201,7 +201,7 @@ def test_collector() -> None: Collector(policy, dummy_venv_4_envs, ReplayBuffer(10)) with pytest.raises(TypeError): Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) - with pytest.raises(TypeError): + with pytest.raises(ValueError): c_dummy_venv_4_envs.collect() # test NXEnv diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 751fedfb2..1d5209943 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -308,6 +308,60 @@ def _compute_action_policy_hidden( ) return act_RA, act_normalized_RA, policy_R, hidden_state_RH + def _validate_collect_input_and_get_ready_env_ids( + self, + n_episode: int | None, + n_step: int | None, + sample_equal_num_episodes_per_worker: bool, + ) -> np.ndarray: + """Check that exactly one of n_step or n_episode is specified. + Returns the idx of non-idle envs that will be used for the collection. + """ + if n_step is not None and n_episode is not None: + raise ValueError( + f"Only one of n_step or n_episode is allowed in Collector." + f"collect, got {n_step=}, {n_episode=}.", + ) + + if n_step is not None: + if sample_equal_num_episodes_per_worker: + raise ValueError( + "sample_equal_num_episodes_per_worker can only be used if `n_episode` is specified but" + "got `n_step` instead.", + ) + if n_step < 1: + raise ValueError(f"n_step should be an integer larger than 0, but got {n_step=}.") + + if n_step % self.env_num: + warnings.warn( + f"{n_step=} is not a multiple of ({self.env_num=}). 
" + "This may cause extra transitions to be collected into the buffer.", + ) + return np.arange(self.env_num) + + elif n_episode is not None: + if n_episode < 1: + raise ValueError( + f"{n_episode=} should be an integer larger than 0.", + ) + if n_episode < self.env_num: + warnings.warn( + f"{n_episode=} should be larger than or equal to {self.env_num=} " + f"(otherwise you will get idle workers and won't collect at" + f"least one trajectory in each env).", + ) + if sample_equal_num_episodes_per_worker and n_episode % self.env_num != 0: + raise ValueError( + f"{n_episode=} must be a multiple of {self.env_num=} " + f"when using {sample_equal_num_episodes_per_worker=}.", + ) + return np.arange(min(self.env_num, n_episode)) + + else: + raise ValueError( + f"At least one of {n_step=} and {n_episode=} should be specified as int larger than 0.", + ) + # TODO: reduce complexity, remove the noqa def collect( self, @@ -318,6 +372,7 @@ def collect( no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, + sample_equal_num_episodes_per_worker: bool = False, ) -> CollectStats: """Collect a specified number of steps or episodes. @@ -337,6 +392,8 @@ def collect( (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Only used if reset_before_collect is True. + :param sample_equal_num_episodes_per_worker: whether to sample the same number + of episodes from each worker. Only used if n_episode is set. .. note:: @@ -364,31 +421,12 @@ def collect( # Input validation assert not self.env.is_async, "Please use AsyncCollector if using async venv." - if n_step is not None: - assert n_episode is None, ( - f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got {n_step=}, {n_episode=}." - ) - assert n_step > 0 - if n_step % self.env_num != 0: - warnings.warn( - f"{n_step=} is not a multiple of ({self.env_num=}), " - "which may cause extra transitions being collected into the buffer.", - ) - ready_env_ids_R = np.arange(self.env_num) - elif n_episode is not None: - assert n_episode > 0 - if self.env_num > n_episode: - warnings.warn( - f"{n_episode=} should be larger than {self.env_num=} to " - f"collect at least one trajectory in each environment.", - ) - ready_env_ids_R = np.arange(min(self.env_num, n_episode)) - else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) + + ready_env_ids_R = self._validate_collect_input_and_get_ready_env_ids( + n_episode, + n_step, + sample_equal_num_episodes_per_worker=False, + ) start_time = time.time() @@ -668,6 +706,7 @@ def collect( no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, + sample_equal_num_episodes_per_worker: bool = False, ) -> CollectStats: """Collect a specified number of steps or episodes with async env setting. @@ -689,6 +728,8 @@ def collect( (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) + :param sample_equal_num_episodes_per_worker: Not applicable to async collector. + #todo this is only used to keep the signatures of collect in Collector and AsyncCollector the same, maybe introduce some base class with collect as abstract method? .. 
note:: From a00b5424faabb3cae0a772bccc5cc22c55c1f2b8 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 3 Apr 2024 11:04:46 +0200 Subject: [PATCH 3/9] WIP: base classes for collector --- tianshou/data/collector.py | 205 ++++++++++++++++++------------- tianshou/highlevel/agent.py | 4 +- tianshou/highlevel/experiment.py | 4 +- 3 files changed, 125 insertions(+), 88 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 751fedfb2..b2733678e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,8 +1,9 @@ import time import warnings +from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass -from typing import Any, Self, TypeVar, cast +from typing import Any, Generic, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -122,39 +123,65 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: return result_batch_parent.info -class Collector: +TBuffer = TypeVar("TBuffer", bound=ReplayBuffer) + + +class CollectorBase(ABC, Generic[TBuffer]): + @abstractmethod + def get_buffer(self) -> TBuffer: + pass + + @abstractmethod + def reset(self) -> None: + pass + + @abstractmethod + def close(self) -> None: + pass + + @abstractmethod + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + reset_before_collect: bool = True, + ) -> CollectStats: + pass + + +class _CollectorWithInit(ABC, Generic[TBuffer]): """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the - :class:`~tianshou.env.BaseVectorEnv` class. + :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` - as the default buffer. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. :param exploration_noise: determine whether the action needs to be modified - with the corresponding policy's exploration noise. If so, "policy. - exploration_noise(act, batch)" will be called automatically to add the - exploration noise into action. Default to False. + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. .. note:: - Please make sure the given environment has a time limitation if using n_episode - collect option. + Please make sure the given environment has a time limitation if using n_episode + collect option. .. note:: - In past versions of Tianshou, the replay buffer passed to `__init__` - was automatically reset. This is not done in the current implementation. + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. 
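+        The buffer type is tracked through the ``TBuffer`` type variable, so subclasses
+        can expose a precisely typed buffer via ``get_buffer``. A minimal sketch
+        (``policy`` and ``venv`` are assumed to exist already)::
+
+            collector: Collector[VectorReplayBuffer] = Collector(policy, venv, VectorReplayBuffer(2000, 4))
+            buf = collector.get_buffer()  # statically typed as VectorReplayBuffer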
""" def __init__( self, policy: BasePolicy, env: gym.Env | BaseVectorEnv, - buffer: ReplayBuffer | None = None, + buffer: TBuffer | None = None, exploration_noise: bool = False, ) -> None: - super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy @@ -174,6 +201,62 @@ def __init__( self._is_closed = False self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + def _compute_action_policy_hidden( + self, + random: bool, + ready_env_ids_R: np.ndarray, + use_grad: bool, + last_obs_RO: np.ndarray, + last_info_R: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, + ) -> tuple[np.ndarray, np.ndarray, Batch, np.ndarray | torch.Tensor | Batch | None]: + """Returns the action, the normalized action, a "policy" entry, and the hidden state.""" + if random: + try: + act_normalized_RA = np.array( + [self._action_space[i].sample() for i in ready_env_ids_R], + ) + # TODO: test whether envpool env explicitly + except TypeError: # envpool's action space is not for per-env + act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) + act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) + policy_R = Batch() + hidden_state_RH = None + + else: + info_batch = _HACKY_create_info_batch(last_info_R) + obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) + + with torch.set_grad_enabled(use_grad): + act_batch_RA = self.policy( + obs_batch_R, + last_hidden_state_RH, + ) + + act_RA = to_numpy(act_batch_RA.act) + if self.exploration_noise: + act_RA = self.policy.exploration_noise(act_RA, obs_batch_R) + act_normalized_RA = self.policy.map_action(act_RA) + + # TODO: cleanup the whole policy in batch thing + # todo policy_R can also be none, check + policy_R = act_batch_RA.get("policy", Batch()) + if not isinstance(policy_R, Batch): + raise RuntimeError( + f"The policy result should be a {Batch}, but got {type(policy_R)}", + ) + + hidden_state_RH = act_batch_RA.get("state", None) + # TODO: do we need the conditional? Would be better to just add hidden_state which could be None + if hidden_state_RH is not None: + policy_R.hidden_state = ( + hidden_state_RH # save state into buffer through policy attr + ) + return act_RA, act_normalized_RA, policy_R, hidden_state_RH + + def get_buffer(self) -> TBuffer: + return self.buffer + def close(self) -> None: """Close the collector and the environment.""" self.env.close() @@ -186,7 +269,8 @@ def is_closed(self) -> bool: """Return True if the collector is closed.""" return self._is_closed - def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: + # TODO: remove + def _assign_buffer(self, buffer: TBuffer | None) -> TBuffer: """Check if the buffer matches the constraint.""" if buffer is None: buffer = VectorReplayBuffer(self.env_num, self.env_num) @@ -251,63 +335,29 @@ def reset_env( # this can happen if the env is an envpool env. 
Then the thing returned by reset is a dict # with array entries instead of an array of dicts # We use Batch to turn it into an array of dicts - self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R) # type: ignore[unreachable] + self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts( + self._pre_collect_info_R, + ) # type: ignore[unreachable] self._pre_collect_hidden_state_RH = None - def _compute_action_policy_hidden( - self, - random: bool, - ready_env_ids_R: np.ndarray, - use_grad: bool, - last_obs_RO: np.ndarray, - last_info_R: np.ndarray, - last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, - ) -> tuple[np.ndarray, np.ndarray, Batch, np.ndarray | torch.Tensor | Batch | None]: - """Returns the action, the normalized action, a "policy" entry, and the hidden state.""" - if random: - try: - act_normalized_RA = np.array( - [self._action_space[i].sample() for i in ready_env_ids_R], - ) - # TODO: test whether envpool env explicitly - except TypeError: # envpool's action space is not for per-env - act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) - act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) - policy_R = Batch() - hidden_state_RH = None - - else: - info_batch = _HACKY_create_info_batch(last_info_R) - obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) - - with torch.set_grad_enabled(use_grad): - act_batch_RA = self.policy( - obs_batch_R, - last_hidden_state_RH, - ) - - act_RA = to_numpy(act_batch_RA.act) - if self.exploration_noise: - act_RA = self.policy.exploration_noise(act_RA, obs_batch_R) - act_normalized_RA = self.policy.map_action(act_RA) - - # TODO: cleanup the whole policy in batch thing - # todo policy_R can also be none, check - policy_R = act_batch_RA.get("policy", Batch()) - if not isinstance(policy_R, Batch): - raise RuntimeError( - f"The policy result should be a {Batch}, but got {type(policy_R)}", - ) + @staticmethod + def _reset_hidden_state_based_on_type( + env_ind_local_D: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, + ) -> None: + if isinstance(last_hidden_state_RH, torch.Tensor): + last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] + elif isinstance(last_hidden_state_RH, np.ndarray): + last_hidden_state_RH[env_ind_local_D] = ( + None if last_hidden_state_RH.dtype == object else 0 + ) + elif isinstance(last_hidden_state_RH, Batch): + last_hidden_state_RH.empty_(env_ind_local_D) + # todo is this inplace magic and just working? - hidden_state_RH = act_batch_RA.get("state", None) - # TODO: do we need the conditional? 
Would be better to just add hidden_state which could be None - if hidden_state_RH is not None: - policy_R.hidden_state = ( - hidden_state_RH # save state into buffer through policy attr - ) - return act_RA, act_normalized_RA, policy_R, hidden_state_RH +class Collector(_CollectorWithInit[TBuffer], Generic[TBuffer]): # TODO: reduce complexity, remove the noqa def collect( self, @@ -580,23 +630,8 @@ def collect( collect_speed=step_count / collect_time, ) - def _reset_hidden_state_based_on_type( - self, - env_ind_local_D: np.ndarray, - last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, - ) -> None: - if isinstance(last_hidden_state_RH, torch.Tensor): - last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] - elif isinstance(last_hidden_state_RH, np.ndarray): - last_hidden_state_RH[env_ind_local_D] = ( - None if last_hidden_state_RH.dtype == object else 0 - ) - elif isinstance(last_hidden_state_RH, Batch): - last_hidden_state_RH.empty_(env_ind_local_D) - # todo is this inplace magic and just working? - -class AsyncCollector(Collector): +class AsyncCollector(_CollectorWithInit): """Async Collector handles async vector environment. The arguments are exactly the same as :class:`~tianshou.data.Collector`, please @@ -854,7 +889,9 @@ def collect( # todo seem we can get rid of this last_sth stuff altogether last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) - last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R]) # type: ignore[index] + last_hidden_state_RH = copy( + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R], + ) # type: ignore[index] if num_episodes_done_this_iter: env_ind_local_D = np.where(done_R)[0] diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index f71a7f981..bc2fc96dc 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -119,8 +119,8 @@ def create_train_test_collector( save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, ) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, envs.test_envs) + train_collector: Collector[ReplayBuffer] = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector: Collector[ReplayBuffer] = Collector(policy, envs.test_envs) if reset_collectors: train_collector.reset() test_collector.reset() diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 17f0550fc..b70976122 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -9,7 +9,7 @@ import numpy as np import torch -from tianshou.data import Collector, InfoStats +from tianshou.data import Collector, InfoStats, ReplayBuffer, VectorReplayBuffer from tianshou.env import BaseVectorEnv from tianshou.highlevel.agent import ( A2CAgentFactory, @@ -310,7 +310,7 @@ def _watch_agent( render: float, ) -> None: policy.eval() - collector = Collector(policy, env) + collector: Collector[VectorReplayBuffer] = Collector(policy, env) result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy From 68620e9b27c489a88881aa3317b37342514d6f34 Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Wed, 3 Apr 2024 23:23:01 +0200 Subject: [PATCH 4/9] Extract BaseCollectorClass --- tianshou/data/collector.py | 238 
+++++++++++++++++++++---------------- 1 file changed, 134 insertions(+), 104 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1d5209943..285ea90c1 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,5 +1,6 @@ import time import warnings +from abc import abstractmethod from copy import copy from dataclasses import dataclass from typing import Any, Self, TypeVar, cast @@ -122,29 +123,29 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: return result_batch_parent.info -class Collector: +class BaseCollector: """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the - :class:`~tianshou.env.BaseVectorEnv` class. + :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` - as the default buffer. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. :param exploration_noise: determine whether the action needs to be modified - with the corresponding policy's exploration noise. If so, "policy. - exploration_noise(act, batch)" will be called automatically to add the - exploration noise into action. Default to False. + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. .. note:: - Please make sure the given environment has a time limitation if using n_episode - collect option. + Please make sure the given environment has a time limitation if using n_episode + collect option. .. note:: - In past versions of Tianshou, the replay buffer passed to `__init__` - was automatically reset. This is not done in the current implementation. + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. """ def __init__( @@ -154,7 +155,6 @@ def __init__( buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> None: - super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy @@ -210,27 +210,6 @@ def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: ) return buffer - def reset( - self, - reset_buffer: bool = True, - reset_stats: bool = True, - gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: - """Reset the environment, statistics, and data needed to start the collection. - - :param reset_buffer: if true, reset the replay buffer attached - to the collector. - :param reset_stats: if true, reset the statistics attached to the collector. - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. 
Defaults to None (extra keyword arguments) - """ - self.reset_env(gym_reset_kwargs=gym_reset_kwargs) - if reset_buffer: - self.reset_buffer() - if reset_stats: - self.reset_stat() - self._is_closed = False - def reset_stat(self) -> None: """Reset the statistic variables.""" self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 @@ -308,59 +287,88 @@ def _compute_action_policy_hidden( ) return act_RA, act_normalized_RA, policy_R, hidden_state_RH - def _validate_collect_input_and_get_ready_env_ids( + @staticmethod + def _reset_hidden_state_based_on_type( + env_ind_local_D: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, + ) -> None: + if isinstance(last_hidden_state_RH, torch.Tensor): + last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] + elif isinstance(last_hidden_state_RH, np.ndarray): + last_hidden_state_RH[env_ind_local_D] = ( + None if last_hidden_state_RH.dtype == object else 0 + ) + elif isinstance(last_hidden_state_RH, Batch): + last_hidden_state_RH.empty_(env_ind_local_D) + + @abstractmethod + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + pass + + def reset( self, - n_episode: int | None, - n_step: int | None, - sample_equal_num_episodes_per_worker: bool, - ) -> np.ndarray: - """Check that exactly one of n_step or n_episode is specified. - Returns the idx of non-idle envs that will be used for the collection. + reset_buffer: bool = True, + reset_stats: bool = True, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> None: + """Reset the environment, statistics, and data needed to start the collection. + + :param reset_buffer: if true, reset the replay buffer attached + to the collector. + :param reset_stats: if true, reset the statistics attached to the collector. + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (extra keyword arguments) """ - if n_step is not None and n_episode is not None: - raise ValueError( - f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got {n_step=}, {n_episode=}.", - ) + self.reset_env(gym_reset_kwargs=gym_reset_kwargs) + if reset_buffer: + self.reset_buffer() + if reset_stats: + self.reset_stat() + self._is_closed = False - if n_step is not None: - if sample_equal_num_episodes_per_worker: - raise ValueError( - "sample_equal_num_episodes_per_worker can only be used if `n_episode` is specified but" - "got `n_step` instead.", - ) - if n_step < 1: - raise ValueError(f"n_step should be an integer larger than 0, but got {n_step=}.") - if n_step % self.env_num: - warnings.warn( - f"{n_step=} is not a multiple of ({self.env_num=}). " - "This may cause extra transitions to be collected into the buffer.", - ) - return np.arange(self.env_num) +class Collector(BaseCollector): + """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. 
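+
+    A minimal usage sketch (``policy`` is assumed to be a :class:`~tianshou.policy.BasePolicy`
+    instance; imports from ``tianshou.data``/``tianshou.env`` omitted)::
+
+        venv = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
+        collector = Collector(policy, venv, VectorReplayBuffer(total_size=2000, buffer_num=4))
+        collector.reset()
+        stats = collector.collect(n_step=64)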
- elif n_episode is not None: - if n_episode < 1: - raise ValueError( - f"{n_episode=} should be an integer larger than 0.", - ) - if n_episode < self.env_num: - warnings.warn( - f"{n_episode=} should be larger than or equal to {self.env_num=} " - f"(otherwise you will get idle workers and won't collect at" - f"least one trajectory in each env).", - ) - if sample_equal_num_episodes_per_worker and n_episode % self.env_num != 0: - raise ValueError( - f"{n_episode=} must be a multiple of {self.env_num=} " - f"when using {sample_equal_num_episodes_per_worker=}.", - ) - return np.arange(min(self.env_num, n_episode)) + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. - else: - raise ValueError( - f"At least one of {n_step=} and {n_episode=} should be specified as int larger than 0.", - ) + .. note:: + + Please make sure the given environment has a time limitation if using n_episode + collect option. + + .. note:: + + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. + """ + + def __init__( + self, + policy: BasePolicy, + env: gym.Env | BaseVectorEnv, + buffer: ReplayBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + super().__init__(policy, env, buffer, exploration_noise) # TODO: reduce complexity, remove the noqa def collect( @@ -372,7 +380,6 @@ def collect( no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, - sample_equal_num_episodes_per_worker: bool = False, ) -> CollectStats: """Collect a specified number of steps or episodes. @@ -392,8 +399,6 @@ def collect( (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Only used if reset_before_collect is True. - :param sample_equal_num_episodes_per_worker: whether to sample the same number - of episodes from each worker. Only used if n_episode is set. .. note:: @@ -416,16 +421,59 @@ def collect( # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. + def _validate_collect_input_and_get_ready_env_ids( + n_episode: int | None, + n_step: int | None, + ) -> np.ndarray: + """Check that exactly one of n_step or n_episode is specified. + Returns the idx of non-idle envs that will be used for the collection. + """ + if n_step is not None and n_episode is not None: + raise ValueError( + f"Only one of n_step or n_episode is allowed in Collector." + f"collect, got {n_step=}, {n_episode=}.", + ) + + if n_step is not None: + if n_step < 1: + raise ValueError( + f"n_step should be an integer larger than 0, but got {n_step=}.", + ) + + if n_step % self.env_num: + warnings.warn( + f"{n_step=} is not a multiple of ({self.env_num=}). 
" + "This may cause extra transitions to be collected into the buffer.", + ) + return np.arange(self.env_num) + + elif n_episode is not None: + if n_episode < 1: + raise ValueError( + f"{n_episode=} should be an integer larger than 0.", + ) + if n_episode < self.env_num: + warnings.warn( + f"{n_episode=} should be larger than or equal to {self.env_num=} " + f"(otherwise you will get idle workers and won't collect at" + f"least one trajectory in each env).", + ) + return np.arange(min(self.env_num, n_episode)) + + else: + raise ValueError( + f"At least one of {n_step=} and {n_episode=} should be specified as int larger than 0.", + ) + use_grad = not no_grad gym_reset_kwargs = gym_reset_kwargs or {} # Input validation assert not self.env.is_async, "Please use AsyncCollector if using async venv." - ready_env_ids_R = self._validate_collect_input_and_get_ready_env_ids( + ready_env_ids_R = _validate_collect_input_and_get_ready_env_ids( n_episode, n_step, - sample_equal_num_episodes_per_worker=False, ) start_time = time.time() @@ -618,23 +666,8 @@ def collect( collect_speed=step_count / collect_time, ) - def _reset_hidden_state_based_on_type( - self, - env_ind_local_D: np.ndarray, - last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, - ) -> None: - if isinstance(last_hidden_state_RH, torch.Tensor): - last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] - elif isinstance(last_hidden_state_RH, np.ndarray): - last_hidden_state_RH[env_ind_local_D] = ( - None if last_hidden_state_RH.dtype == object else 0 - ) - elif isinstance(last_hidden_state_RH, Batch): - last_hidden_state_RH.empty_(env_ind_local_D) - # todo is this inplace magic and just working? - -class AsyncCollector(Collector): +class AsyncCollector(BaseCollector): """Async Collector handles async vector environment. The arguments are exactly the same as :class:`~tianshou.data.Collector`, please @@ -706,7 +739,6 @@ def collect( no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, - sample_equal_num_episodes_per_worker: bool = False, ) -> CollectStats: """Collect a specified number of steps or episodes with async env setting. @@ -728,8 +760,6 @@ def collect( (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) - :param sample_equal_num_episodes_per_worker: Not applicable to async collector. - #todo this is only used to keep the signatures of collect in Collector and AsyncCollector the same, maybe introduce some base class with collect as abstract method? .. 
note:: From e4d0cb56cd204d17a6f838c9ee3a545c956c6d7a Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Fri, 5 Apr 2024 15:07:27 +0200 Subject: [PATCH 5/9] Introduce Base and Stump class via branch merge --- tianshou/data/collector.py | 173 ++++++++++++------------------- tianshou/highlevel/agent.py | 7 +- tianshou/highlevel/experiment.py | 2 +- 3 files changed, 76 insertions(+), 106 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b2707bee3..b3f9504ef 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -12,7 +12,6 @@ from tianshou.data import ( Batch, CachedReplayBuffer, - PrioritizedReplayBuffer, ReplayBuffer, ReplayBufferManager, SequenceSummaryStats, @@ -123,12 +122,12 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: return result_batch_parent.info -TBuffer = TypeVar("TBuffer", bound=ReplayBuffer) +_TBuffer = TypeVar("_TBuffer", bound=ReplayBuffer) -class CollectorBase(ABC, Generic[TBuffer]): +class CollectorBase(ABC, Generic[_TBuffer]): @abstractmethod - def get_buffer(self) -> TBuffer: + def get_buffer(self) -> _TBuffer: pass @abstractmethod @@ -145,12 +144,15 @@ def collect( n_step: int | None = None, n_episode: int | None = None, random: bool = False, - reset_before_collect: bool = True, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: pass -class _CollectorWithInit(ABC, Generic[TBuffer]): +class _CollectorStump(ABC, Generic[_TBuffer]): """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. @@ -179,7 +181,7 @@ def __init__( self, policy: BasePolicy, env: gym.Env | BaseVectorEnv, - buffer: TBuffer | None = None, + buffer: _TBuffer | None = None, exploration_noise: bool = False, ) -> None: if isinstance(env, gym.Env) and not hasattr(env, "__len__"): @@ -214,7 +216,7 @@ def is_closed(self) -> bool: return self._is_closed # TODO: remove - def _assign_buffer(self, buffer: TBuffer | None) -> TBuffer: + def _assign_buffer(self, buffer: _TBuffer | None) -> _TBuffer: """Check if the buffer matches the constraint.""" if buffer is None: buffer = VectorReplayBuffer(self.env_num, self.env_num) @@ -225,16 +227,9 @@ def _assign_buffer(self, buffer: TBuffer | None) -> TBuffer: else: # ReplayBuffer or PrioritizedReplayBuffer assert buffer.maxsize > 0 if self.env_num > 1: - if isinstance(buffer, ReplayBuffer): - buffer_type = "ReplayBuffer" - vector_type = "VectorReplayBuffer" - if isinstance(buffer, PrioritizedReplayBuffer): - buffer_type = "PrioritizedReplayBuffer" - vector_type = "PrioritizedVectorReplayBuffer" raise TypeError( - f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect " - f"{self.env_num} envs,\n\tplease use {vector_type}(total_size=" - f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.", + f"Cannot use {buffer.__class__.__name__}(size={buffer.maxsize}, ...) 
to collect " + f"{self.env_num} envs, please use a corresponding vectorized buffer instead.", ) return buffer @@ -329,19 +324,6 @@ def _reset_hidden_state_based_on_type( elif isinstance(last_hidden_state_RH, Batch): last_hidden_state_RH.empty_(env_ind_local_D) - @abstractmethod - def collect( - self, - n_step: int | None = None, - n_episode: int | None = None, - random: bool = False, - render: float | None = None, - no_grad: bool = True, - reset_before_collect: bool = False, - gym_reset_kwargs: dict[str, Any] | None = None, - ) -> CollectStats: - pass - def reset( self, reset_buffer: bool = True, @@ -363,8 +345,51 @@ def reset( self.reset_stat() self._is_closed = False + def _validate_collect_input( + self, + n_episode: int | None, + n_step: int | None, + ) -> None: + """Check that exactly one of n_step or n_episode is specified. + Returns the idx of non-idle envs that will be used for the collection. + """ + if n_step is not None and n_episode is not None: + raise ValueError( + f"Only one of n_step or n_episode is allowed in Collector." + f"collect, got {n_step=}, {n_episode=}.", + ) + + if n_step is not None: + if n_step < 1: + raise ValueError( + f"n_step should be an integer larger than 0, but got {n_step=}.", + ) + + if n_step % self.env_num: + warnings.warn( + f"{n_step=} is not a multiple of ({self.env_num=}). " + "This may cause extra transitions to be collected into the buffer.", + ) + + elif n_episode is not None: + if n_episode < 1: + raise ValueError( + f"{n_episode=} should be an integer larger than 0.", + ) + if n_episode < self.env_num: + warnings.warn( + f"{n_episode=} should be larger than or equal to {self.env_num=} " + f"(otherwise you will get idle workers and won't collect at" + f"least one trajectory in each env).", + ) + + else: + raise ValueError( + f"At least one of {n_step=} and {n_episode=} should be specified as int larger than 0.", + ) + -class Collector(_CollectorWithInit[TBuffer], Generic[TBuffer]): +class Collector(_CollectorStump[_TBuffer], Generic[_TBuffer]): """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. @@ -393,13 +418,11 @@ def __init__( self, policy: BasePolicy, env: gym.Env | BaseVectorEnv, - buffer: ReplayBuffer | None = None, + buffer: _TBuffer | None = None, exploration_noise: bool = False, ) -> None: super().__init__(policy, env, buffer, exploration_noise) -class Collector(_CollectorWithInit[TBuffer], Generic[TBuffer]): - # TODO: reduce complexity, remove the noqa def collect( self, n_step: int | None = None, @@ -450,61 +473,18 @@ def collect( # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. - def _validate_collect_input_and_get_ready_env_ids( - n_episode: int | None, - n_step: int | None, - ) -> np.ndarray: - """Check that exactly one of n_step or n_episode is specified. - Returns the idx of non-idle envs that will be used for the collection. - """ - if n_step is not None and n_episode is not None: - raise ValueError( - f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got {n_step=}, {n_episode=}.", - ) - - if n_step is not None: - if n_step < 1: - raise ValueError( - f"n_step should be an integer larger than 0, but got {n_step=}.", - ) - - if n_step % self.env_num: - warnings.warn( - f"{n_step=} is not a multiple of ({self.env_num=}). 
" - "This may cause extra transitions to be collected into the buffer.", - ) - return np.arange(self.env_num) - - elif n_episode is not None: - if n_episode < 1: - raise ValueError( - f"{n_episode=} should be an integer larger than 0.", - ) - if n_episode < self.env_num: - warnings.warn( - f"{n_episode=} should be larger than or equal to {self.env_num=} " - f"(otherwise you will get idle workers and won't collect at" - f"least one trajectory in each env).", - ) - return np.arange(min(self.env_num, n_episode)) + # Input validation + assert not self.env.is_async, "Please use AsyncCollector if using async venv." + self._validate_collect_input(n_episode, n_step) - else: - raise ValueError( - f"At least one of {n_step=} and {n_episode=} should be specified as int larger than 0.", - ) + if n_episode: + ready_env_ids_R = np.arange(min(self.env_num, n_episode)) + else: # n_step case + ready_env_ids_R = np.arange(self.env_num) use_grad = not no_grad gym_reset_kwargs = gym_reset_kwargs or {} - # Input validation - assert not self.env.is_async, "Please use AsyncCollector if using async venv." - - ready_env_ids_R = _validate_collect_input_and_get_ready_env_ids( - n_episode, - n_step, - ) - start_time = time.time() if reset_before_collect: @@ -696,7 +676,7 @@ def _validate_collect_input_and_get_ready_env_ids( ) -class AsyncCollector(_CollectorWithInit): +class AsyncCollector(_CollectorStump[_TBuffer], Generic[_TBuffer]): """Async Collector handles async vector environment. The arguments are exactly the same as :class:`~tianshou.data.Collector`, please @@ -707,7 +687,7 @@ def __init__( self, policy: BasePolicy, env: BaseVectorEnv, - buffer: ReplayBuffer | None = None, + buffer: _TBuffer | None = None, exploration_noise: bool = False, ) -> None: # assert env.is_async @@ -797,24 +777,11 @@ def collect( :return: A dataclass object """ + self._validate_collect_input(n_episode, n_step) + use_grad = not no_grad gym_reset_kwargs = gym_reset_kwargs or {} - # collect at least n_step or n_episode - if n_step is not None: - assert n_episode is None, ( - "Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}." 
- ) - assert n_step > 0 - elif n_episode is not None: - assert n_episode > 0 - else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) - if reset_before_collect: # first we need to step all envs to be able to interact with them if self.env.waiting_id: @@ -954,9 +921,7 @@ def collect( # todo seem we can get rid of this last_sth stuff altogether last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) - last_hidden_state_RH = copy( - self._current_hidden_state_in_all_envs_EH[ready_env_ids_R], - ) # type: ignore[index] + last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R]) # type: ignore[index] if num_episodes_done_this_iter: env_ind_local_D = np.where(done_R)[0] diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index bc2fc96dc..c394f4b72 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -119,7 +119,12 @@ def create_train_test_collector( save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, ) - train_collector: Collector[ReplayBuffer] = Collector(policy, train_envs, buffer, exploration_noise=True) + train_collector: Collector[ReplayBuffer] = Collector( + policy, + train_envs, + buffer, + exploration_noise=True, + ) test_collector: Collector[ReplayBuffer] = Collector(policy, envs.test_envs) if reset_collectors: train_collector.reset() diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index b70976122..b5df56bfe 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -9,7 +9,7 @@ import numpy as np import torch -from tianshou.data import Collector, InfoStats, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, InfoStats, VectorReplayBuffer from tianshou.env import BaseVectorEnv from tianshou.highlevel.agent import ( A2CAgentFactory, From da3da8d420e7cbd6aa2dc405d45647baf753846b Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Thu, 18 Apr 2024 17:36:59 +0200 Subject: [PATCH 6/9] Test and draft for collector first k episodes and equal across all envs --- test/base/test_collector.py | 28 ++ tianshou/data/collector.py | 531 ++++++++++++++++++++++++++++++++++++ 2 files changed, 559 insertions(+) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d9d818ae9..0d8a074e2 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -12,6 +12,10 @@ ReplayBuffer, VectorReplayBuffer, ) +from tianshou.data.collector import ( + Collector_Equal_Num_Episodes_Per_Env, + Collector_First_K_Episodes, +) from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy @@ -212,6 +216,30 @@ def test_collector() -> None: c_suproc_new.collect(n_step=6) assert c_suproc_new.buffer.obs.dtype == object +def test_sample_first_k_episodes_collector() -> None: + env_lens = [1,8,9,10] + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in env_lens] + venv = SubprocVectorEnv(env_fns) + policy = MaxActionPolicy() + c1 = Collector_First_K_Episodes(policy, + venv, + VectorReplayBuffer(total_size=100, buffer_num=4)) + c1.reset() + res = c1.collect(n_episode= 12) + assert np.array_equal(res.lens, np.array([ 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9, 1, 10])) + +def test_Collector_Equal_Num_Episodes_Per_Env() -> None: + env_lens = [1, 8, 9, 10] + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in 
env_lens] + venv = SubprocVectorEnv(env_fns) + policy = MaxActionPolicy() + c1 = Collector_Equal_Num_Episodes_Per_Env(policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4)) + c1.reset() + with pytest.raises(ValueError) as excinfo: + c1.collect(n_episode=9) + assert str(excinfo.value) == "n_episode has to be a multiple of the number of envs, but got n_episode=9, self.env_num=4." + res = c1.collect(n_episode=12) + assert np.array_equal(res.lens, np.array([1, 8, 9, 10, 1, 8, 9, 10, 1, 8, 9, 10])) @pytest.fixture() def get_AsyncCollector(): diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b3f9504ef..b185e2996 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -676,6 +676,537 @@ def collect( ) + + +class Collector_First_K_Episodes(_CollectorStump[_TBuffer], Generic[_TBuffer]): + """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. + + .. note:: + + Please make sure the given environment has a time limitation if using n_episode + collect option. + + .. note:: + + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. + """ + + def __init__( + self, + policy: BasePolicy, + env: gym.Env | BaseVectorEnv, + buffer: _TBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + super().__init__(policy, env, buffer, exploration_noise) + + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + + # Input validation + assert not self.env.is_async, "Please use AsyncCollector if using async venv." + self._validate_collect_input(n_episode, n_step) + + #we this is really only applicable in the n_episode case, maybe refactor + # to some keyword or froce the user to explicitly specify n_episode or n_step + + if n_episode: + ready_env_ids_R = np.arange(min(self.env_num, n_episode)) + else: # n_step case + ready_env_ids_R = np.arange(self.env_num) + + use_grad = not no_grad + gym_reset_kwargs = gym_reset_kwargs or {} + + start_time = time.time() + + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: + raise ValueError( + "Initial obs and info should not be None. 
" + "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.", + ) + + # get the first obs to be the current obs in the n_step case as + # episodes as a new call to collect does not restart trajectories + # (which we also really don't want) + step_count = 0 + num_collected_episodes = 0 + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] + + # in case we select fewer episodes than envs, we run only some of them + last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R) + last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R) + last_hidden_state_RH = _nullable_slice( + self._pre_collect_hidden_state_RH, + ready_env_ids_R, + ) + + while True: + # todo check if we need this when using cur_rollout_batch + # if len(cur_rollout_batch) != len(ready_env_ids): + # raise RuntimeError( + # f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids" + # f"{len(ready_env_ids)}. This should not happen and could be a bug!", + # ) + # restore the state: if the last state is None, it won't store + + # get the next action + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + use_grad=use_grad, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) + + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, + ) + if isinstance(info_R, dict): # type: ignore[unreachable] + # This can happen if the env is an envpool env. Then the info returned by step is a dict + info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] + done_R = np.logical_or(terminated_R, truncated_R) + + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=last_obs_RO, + act=act_RA, + policy=policy_R, + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), + ) + + # TODO: only makes sense if render_mode is human. 
+ # Also, doubtful whether it makes sense at all for true vectorized envs + if render: + self.env.render() + if not np.isclose(render, 0): + time.sleep(render) + + # add data into the buffer + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) + + # collect statistics + num_episodes_done_this_iter = np.sum(done_R) + num_collected_episodes += num_episodes_done_this_iter + step_count += len(ready_env_ids_R) + + # preparing for the next iteration + # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(hidden_state_RH) + + # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration + # Resetting envs that reached done, or removing some of them from the collection if needed (see below) + if num_episodes_done_this_iter > 0: + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + # now we copy obs_next to obs, but since there might be + # finished episodes, we have to reset finished envs first. + + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + + # Set the hidden state to zero or None for the envs that reached done + # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of + # this complex logic + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + + # preparing for the next iteration + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D + + # Handling the case when we have more ready envs than desired and are not done yet + # + # This can only happen if we are collecting a fixed number of episodes + # If we have more ready envs than there are remaining episodes to collect, + # we will remove some of them for the next rollout + # One effect of this is the following: only envs that have completed an episode + # in the last step can ever be removed from the ready envs. + # Thus, this guarantees that each env will contribute at least one episode to the + # collected data (the buffer). This effect was previous called "avoiding bias in selecting environments" + # However, it is not at all clear whether this is actually useful or necessary. + # Additional naming convention: + # S - number of surplus envs + # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones. + # Changing R to R-S highly increases the complexity of the code. 
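+                # A possible post-collection trim (sketch only, not applied in this class):
+                # if more episodes finished in the last step than requested, the surplus
+                # could simply be dropped from the collected statistics instead of
+                # shrinking R, e.g.
+                #   surplus = num_collected_episodes - n_episode
+                #   if n_episode and surplus > 0:
+                #       episode_lens = episode_lens[:-surplus]
+                #       episode_returns = episode_returns[:-surplus]
+                #       episode_start_indices = episode_start_indices[:-surplus]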
+ if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): + break + + # generate statistics + self.collect_step += step_count + self.collect_episode += num_collected_episodes + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time + + if n_step: + # persist for future collect iterations + self._pre_collect_obs_RO = last_obs_RO + self._pre_collect_info_R = last_info_R + self._pre_collect_hidden_state_RH = last_hidden_state_RH + elif n_episode: + # reset envs and the _pre_collect fields + self.reset_env(gym_reset_kwargs) # todo still necessary? + + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + ) + +class Collector_Equal_Num_Episodes_Per_Env(_CollectorStump[_TBuffer], Generic[_TBuffer]): + """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. + + .. note:: + + Please make sure the given environment has a time limitation if using n_episode + collect option. + + .. note:: + + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. + + This is not the most efficient implementation possible. The collector + collects one episode in every environment, thus a environment having reached + done is staying idle until the "slowest" is reaching done. This could be + optimized by explicitly trcking how many trajectories have been collected in each + environment. + """ + + def __init__( + self, + policy: BasePolicy, + env: gym.Env | BaseVectorEnv, + buffer: _TBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + super().__init__(policy, env, buffer, exploration_noise) + + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + """Collect a specified number of steps or episodes. + + To ensure an unbiased sampling result with the n_episode option, this function will + first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` + episodes, they will be collected evenly from each env. + + :param n_step: how many steps you want to collect. + :param n_episode: how many episodes you want to collect. + :param random: whether to use random policy for collecting data. + :param render: the sleep time between rendering consecutive frames. + :param no_grad: whether to retain gradient in policy.forward(). + :param reset_before_collect: whether to reset the environment before + collecting data. 
+ It has only an effect if n_episode is not None, i.e. + if one wants to collect a fixed number of episodes. + (The collector needs the initial obs and info to function properly.) + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Only used if reset_before_collect is True. + + .. note:: + + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. + + :return: The collected stats + """ + # NAMING CONVENTION (mostly suffixes): + # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, + # the corresponding env is either reset or removed from the ready envs. + # R - number ready env ids. Note that this might change when envs get idle. + # This can only happen in n_episode case, see explanation in the corresponding block. + # For n_step, we always use all envs to collect the data, while for n_episode, + # R will be at most n_episode at the beginning, but can decrease during the collection. + # O - dimension(s) of observations + # A - dimension(s) of actions + # H - dimension(s) of hidden state + # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. + # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. + # Only used in n_episode case. Then, R becomes R-S. + + # Input validation + assert not self.env.is_async, "Please use AsyncCollector if using async venv." + self._validate_collect_input(n_episode, n_step) + if not n_episode % self.env_num == 0: + raise ValueError( + f"n_episode has to be a multiple of the number of envs, but got {n_episode=}, {self.env_num=}.", + ) + if n_episode: + ready_env_ids_R = np.arange(min(self.env_num, n_episode)) + else: # n_step case #todo not needed remove + ready_env_ids_R = np.arange(self.env_num) + + non_idle_env_ids = np.arange(self.env_num) + + use_grad = not no_grad + gym_reset_kwargs = gym_reset_kwargs or {} + + start_time = time.time() + + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: + raise ValueError( + "Initial obs and info should not be None. " + "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.", + ) + + # get the first obs to be the current obs in the n_step case as + # episodes as a new call to collect does not restart trajectories + # (which we also really don't want) + step_count = 0 + num_collected_episodes = 0 + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] + + # in case we select fewer episodes than envs, we run only some of them + last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R) + last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R) + last_hidden_state_RH = _nullable_slice( + self._pre_collect_hidden_state_RH, + ready_env_ids_R, + ) + + while True: + # todo check if we need this when using cur_rollout_batch + # if len(cur_rollout_batch) != len(ready_env_ids): + # raise RuntimeError( + # f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids" + # f"{len(ready_env_ids)}. 
This should not happen and could be a bug!", + # ) + # restore the state: if the last state is None, it won't store + + # get the next action + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + use_grad=use_grad, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) + + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, + ) + if isinstance(info_R, dict): # type: ignore[unreachable] + # This can happen if the env is an envpool env. Then the info returned by step is a dict + info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] + done_R = np.logical_or(terminated_R, truncated_R) + is_done_and_not_idle = done_R + + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=last_obs_RO, + act=act_RA, + policy=policy_R, + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), + ) + + # TODO: only makes sense if render_mode is human. + # Also, doubtful whether it makes sense at all for true vectorized envs + if render: + self.env.render() + if not np.isclose(render, 0): + time.sleep(render) + + # add data into the buffer + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) + + # collect statistics + num_episodes_done_this_iter = np.sum(done_R) + num_collected_episodes += num_episodes_done_this_iter + step_count += len(ready_env_ids_R) + + # preparing for the next iteration + # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(hidden_state_RH) + + # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration + # Resetting envs that reached done, or removing some of them from the collection if needed (see below) + if num_episodes_done_this_iter > 0: + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + # now we copy obs_next to obs, but since there might be + + # we make sure to now only step the envs that aren't done and reset only once all episodes have reached done + + + # finished episodes, we have to reset finished envs first. + + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + + # Set the hidden state to zero or None for the envs that reached done + # TODO: does it have to be so complicated? 
We should have a single clear type for hidden_state instead of + # this complex logic + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + + # preparing for the next iteration + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D + + # Handling the case when we have more ready envs than desired and are not done yet + # + # This can only happen if we are collecting a fixed number of episodes + # If we have more ready envs than there are remaining episodes to collect, + # we will remove some of them for the next rollout + # One effect of this is the following: only envs that have completed an episode + # in the last step can ever be removed from the ready envs. + # Thus, this guarantees that each env will contribute at least one episode to the + # collected data (the buffer). This effect was previous called "avoiding bias in selecting environments" + # However, it is not at all clear whether this is actually useful or necessary. + # Additional naming convention: + # S - number of surplus envs + # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones. + # Changing R to R-S highly increases the complexity of the code. + if n_episode: + remaining_episodes_to_collect = n_episode - num_collected_episodes + surplus_env_num = len(ready_env_ids_R) - remaining_episodes_to_collect + if surplus_env_num > 0: + # R becomes R-S here, preparing for the next iteration in while loop + # Everything that was of length R needs to be filtered and become of length R-S. + # Note that this won't be the last iteration, as one iteration equals one + # step and we still need to collect the remaining episodes to reach the breaking condition. + + # creating the mask + env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num] + env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) + env_should_remain_R[env_to_be_ignored_ind_local_S] = False + # stripping the "idle" indices, shortening the relevant quantities from R to R-S + ready_env_ids_R = ready_env_ids_R[env_should_remain_R] + last_obs_RO = last_obs_RO[env_should_remain_R] + last_info_R = last_info_R[env_should_remain_R] + if hidden_state_RH is not None: + last_hidden_state_RH = last_hidden_state_RH[ + env_should_remain_R] # type: ignore[index] + + if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): + break + + # generate statistics + self.collect_step += step_count + self.collect_episode += num_collected_episodes + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time + + if n_episode: + # reset envs and the _pre_collect fields + self.reset_env(gym_reset_kwargs) # todo still necessary? + + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + ) + + class AsyncCollector(_CollectorStump[_TBuffer], Generic[_TBuffer]): """Async Collector handles async vector environment. From bc003bf18ad4ff1f2225663d6a769fee9aabbe62 Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Mon, 22 Apr 2024 19:15:34 +0200 Subject: [PATCH 7/9] Collect same number of episodes in each worker. 
Cleanup sync collector input checking --- test/base/test_collector.py | 24 +++-- tianshou/data/collector.py | 200 +++++++++++++++++------------------- 2 files changed, 112 insertions(+), 112 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 0d8a074e2..2361b0009 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -216,31 +216,39 @@ def test_collector() -> None: c_suproc_new.collect(n_step=6) assert c_suproc_new.buffer.obs.dtype == object + def test_sample_first_k_episodes_collector() -> None: - env_lens = [1,8,9,10] + env_lens = [1, 8, 9, 10] env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in env_lens] venv = SubprocVectorEnv(env_fns) policy = MaxActionPolicy() - c1 = Collector_First_K_Episodes(policy, - venv, - VectorReplayBuffer(total_size=100, buffer_num=4)) + c1 = Collector_First_K_Episodes(policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4)) c1.reset() - res = c1.collect(n_episode= 12) - assert np.array_equal(res.lens, np.array([ 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9, 1, 10])) + res = c1.collect(n_episode=12) + assert np.array_equal(res.lens, np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9, 1, 10])) + def test_Collector_Equal_Num_Episodes_Per_Env() -> None: env_lens = [1, 8, 9, 10] env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in env_lens] venv = SubprocVectorEnv(env_fns) policy = MaxActionPolicy() - c1 = Collector_Equal_Num_Episodes_Per_Env(policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4)) + c1 = Collector_Equal_Num_Episodes_Per_Env( + policy, + venv, + VectorReplayBuffer(total_size=100, buffer_num=4), + ) c1.reset() with pytest.raises(ValueError) as excinfo: c1.collect(n_episode=9) - assert str(excinfo.value) == "n_episode has to be a multiple of the number of envs, but got n_episode=9, self.env_num=4." + assert ( + str(excinfo.value) + == "n_episode has to be a multiple of the number of envs, but got n_episode=9, self.env_num=4." 
+ ) res = c1.collect(n_episode=12) assert np.array_equal(res.lens, np.array([1, 8, 9, 10, 1, 8, 9, 10, 1, 8, 9, 10])) + @pytest.fixture() def get_AsyncCollector(): env_lens = [2, 3, 4, 5] diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b185e2996..edff72cfc 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -313,7 +313,10 @@ def _compute_action_policy_hidden( @staticmethod def _reset_hidden_state_based_on_type( env_ind_local_D: np.ndarray, - last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, + last_hidden_state_RH: np.ndarray + | torch.Tensor + | Batch + | None, # todo care only about the type not the data in it ) -> None: if isinstance(last_hidden_state_RH, torch.Tensor): last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] @@ -587,30 +590,33 @@ def collect( # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration # Resetting envs that reached done, or removing some of them from the collection if needed (see below) - if num_episodes_done_this_iter > 0: + if np.any(done_R): # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays # D - number of envs that reached done in the rollout above - env_ind_local_D = np.where(done_R)[0] - env_ind_global_D = ready_env_ids_R[env_ind_local_D] - episode_lens.extend(ep_len_R[env_ind_local_D]) - episode_returns.extend(ep_rew_R[env_ind_local_D]) - episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + local_ready_env_ids_done_D = np.where(done_R)[0] + global_ready_env_ids_done_D = ready_env_ids_R[local_ready_env_ids_done_D] + episode_lens.extend(ep_len_R[local_ready_env_ids_done_D]) + episode_returns.extend(ep_rew_R[local_ready_env_ids_done_D]) + episode_start_indices.extend(ep_idx_R[local_ready_env_ids_done_D]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. obs_reset_DO, info_reset_D = self.env.reset( - env_id=env_ind_global_D, + env_id=global_ready_env_ids_done_D, **gym_reset_kwargs, ) # Set the hidden state to zero or None for the envs that reached done # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of # this complex logic - self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + self._reset_hidden_state_based_on_type( + local_ready_env_ids_done_D, + last_hidden_state_RH, + ) # preparing for the next iteration - last_obs_RO[env_ind_local_D] = obs_reset_DO - last_info_R[env_ind_local_D] = info_reset_D + last_obs_RO[local_ready_env_ids_done_D] = obs_reset_DO + last_info_R[local_ready_env_ids_done_D] = info_reset_D # Handling the case when we have more ready envs than desired and are not done yet # @@ -636,7 +642,7 @@ def collect( # step and we still need to collect the remaining episodes to reach the breaking condition. # creating the mask - env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num] + env_to_be_ignored_ind_local_S = local_ready_env_ids_done_D[:surplus_env_num] env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) env_should_remain_R[env_to_be_ignored_ind_local_S] = False # stripping the "idle" indices, shortening the relevant quantities from R to R-S @@ -676,8 +682,6 @@ def collect( ) - - class Collector_First_K_Episodes(_CollectorStump[_TBuffer], Generic[_TBuffer]): """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. 
@@ -722,18 +726,25 @@ def collect(
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> CollectStats:
-        # Input validation
         assert not self.env.is_async, "Please use AsyncCollector if using async venv."
-        self._validate_collect_input(n_episode, n_step)
-
-        #we this is really only applicable in the n_episode case, maybe refactor
-        # to some keyword or froce the user to explicitly specify n_episode or n_step
+        if n_step:
+            raise ValueError(
+                f"First_K_Episodes collector only supports n_episode, but got {n_step=}.",
+            )
+        assert n_episode, "n_episode should be specified."
+        if n_episode < 1:
+            raise ValueError(
+                f"{n_episode=} should be an integer larger than 0.",
+            )
+        if n_episode < self.env_num:
+            warnings.warn(
+                f"{n_episode=} should be larger than or equal to {self.env_num=} "
+                f"(otherwise you will get idle workers and won't collect at "
+                f"least one trajectory in each env).",
+            )
 
-        if n_episode:
-            ready_env_ids_R = np.arange(min(self.env_num, n_episode))
-        else:  # n_step case
-            ready_env_ids_R = np.arange(self.env_num)
+        ready_env_ids_R = np.arange(min(self.env_num, n_episode))
 
         use_grad = not no_grad
         gym_reset_kwargs = gym_reset_kwargs or {}
@@ -908,6 +919,7 @@ def collect(
         collect_speed=step_count / collect_time,
     )
 
+
 class Collector_Equal_Num_Episodes_Per_Env(_CollectorStump[_TBuffer], Generic[_TBuffer]):
     """Collector enables the policy to interact with different types of envs with exact number of steps or episodes.
 
@@ -933,9 +945,10 @@ class Collector_Equal_Num_Episodes_Per_Env(_CollectorStump[_TBuffer], Generic[_T
     was automatically reset. This is not done in the current implementation.
 
     This is not the most efficient implementation possible. The collector
-    collects one episode in every environment, thus a environment having reached
-    done is staying idle until the "slowest" is reaching done. This could be
-    optimized by explicitly trcking how many trajectories have been collected in each
+    collects one episode in every environment, thus an environment that has reached
+    done stays idle until the "slowest" one reaches done. Non-idle envs are referred to as ready.
+    This could be
+    optimized by explicitly tracking how many trajectories have been collected in each
     environment.
     """
 
@@ -949,14 +962,14 @@ def __init__(
         super().__init__(policy, env, buffer, exploration_noise)
 
     def collect(
-            self,
-            n_step: int | None = None,
-            n_episode: int | None = None,
-            random: bool = False,
-            render: float | None = None,
-            no_grad: bool = True,
-            reset_before_collect: bool = False,
-            gym_reset_kwargs: dict[str, Any] | None = None,
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        reset_before_collect: bool = False,
+        gym_reset_kwargs: dict[str, Any] | None = None,
     ) -> CollectStats:
         """Collect a specified number of steps or episodes.
 
@@ -1000,17 +1013,21 @@ def collect(
 
         # Input validation
         assert not self.env.is_async, "Please use AsyncCollector if using async venv."
-        self._validate_collect_input(n_episode, n_step)
+        if n_step:
+            raise ValueError(
+                f"Collector_Equal_Num_Episodes_Per_Env only supports n_episode, but got {n_step=}.",
+            )
+        assert n_episode is not None, "n_episode should be specified."
# needed for mypy + if n_episode < 1: + raise ValueError( + f"{n_episode=} should be an integer larger than 0.", + ) + if not n_episode % self.env_num == 0: raise ValueError( f"n_episode has to be a multiple of the number of envs, but got {n_episode=}, {self.env_num=}.", ) - if n_episode: - ready_env_ids_R = np.arange(min(self.env_num, n_episode)) - else: # n_step case #todo not needed remove - ready_env_ids_R = np.arange(self.env_num) - - non_idle_env_ids = np.arange(self.env_num) + ready_env_ids_R = np.arange(self.env_num) use_grad = not no_grad gym_reset_kwargs = gym_reset_kwargs or {} @@ -1075,7 +1092,6 @@ def collect( # This can happen if the env is an envpool env. Then the info returned by step is a dict info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] done_R = np.logical_or(terminated_R, truncated_R) - is_done_and_not_idle = done_R current_iteration_batch = cast( RolloutBatchProtocol, @@ -1118,72 +1134,48 @@ def collect( # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration # Resetting envs that reached done, or removing some of them from the collection if needed (see below) - if num_episodes_done_this_iter > 0: + if np.any(done_R): # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays # D - number of envs that reached done in the rollout above - env_ind_local_D = np.where(done_R)[0] - env_ind_global_D = ready_env_ids_R[env_ind_local_D] - episode_lens.extend(ep_len_R[env_ind_local_D]) - episode_returns.extend(ep_rew_R[env_ind_local_D]) - episode_start_indices.extend(ep_idx_R[env_ind_local_D]) - # now we copy obs_next to obs, but since there might be - - # we make sure to now only step the envs that aren't done and reset only once all episodes have reached done - - - # finished episodes, we have to reset finished envs first. - - obs_reset_DO, info_reset_D = self.env.reset( - env_id=env_ind_global_D, - **gym_reset_kwargs, - ) - - # Set the hidden state to zero or None for the envs that reached done - # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of - # this complex logic - self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) - - # preparing for the next iteration - last_obs_RO[env_ind_local_D] = obs_reset_DO - last_info_R[env_ind_local_D] = info_reset_D - - # Handling the case when we have more ready envs than desired and are not done yet - # - # This can only happen if we are collecting a fixed number of episodes - # If we have more ready envs than there are remaining episodes to collect, - # we will remove some of them for the next rollout - # One effect of this is the following: only envs that have completed an episode - # in the last step can ever be removed from the ready envs. - # Thus, this guarantees that each env will contribute at least one episode to the - # collected data (the buffer). This effect was previous called "avoiding bias in selecting environments" - # However, it is not at all clear whether this is actually useful or necessary. - # Additional naming convention: - # S - number of surplus envs - # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones. - # Changing R to R-S highly increases the complexity of the code. 
- if n_episode: - remaining_episodes_to_collect = n_episode - num_collected_episodes - surplus_env_num = len(ready_env_ids_R) - remaining_episodes_to_collect - if surplus_env_num > 0: - # R becomes R-S here, preparing for the next iteration in while loop - # Everything that was of length R needs to be filtered and become of length R-S. - # Note that this won't be the last iteration, as one iteration equals one - # step and we still need to collect the remaining episodes to reach the breaking condition. + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + local_ready_env_ids_done_D = np.where(done_R)[0] + episode_lens.extend(ep_len_R[local_ready_env_ids_done_D]) + episode_returns.extend(ep_rew_R[local_ready_env_ids_done_D]) + episode_start_indices.extend(ep_idx_R[local_ready_env_ids_done_D]) + + env_array_ids_non_idle_not_done_R = np.where(~done_R)[0] + # new R, all envs that haven't reached done yet are remain ready + ready_env_ids_R = ready_env_ids_R[env_array_ids_non_idle_not_done_R] + last_obs_RO = last_obs_RO[env_array_ids_non_idle_not_done_R] + last_info_R = last_info_R[env_array_ids_non_idle_not_done_R] + last_hidden_state_RH = last_hidden_state_RH[env_array_ids_non_idle_not_done_R] # type: ignore[index] + + if len(ready_env_ids_R) == 0: + ready_env_ids_R = np.arange(self.env_num) # so now R == E again + + obs_reset_RO, info_reset_R = self.env.reset( + env_id=ready_env_ids_R, + **gym_reset_kwargs, + ) - # creating the mask - env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num] - env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) - env_should_remain_R[env_to_be_ignored_ind_local_S] = False - # stripping the "idle" indices, shortening the relevant quantities from R to R-S - ready_env_ids_R = ready_env_ids_R[env_should_remain_R] - last_obs_RO = last_obs_RO[env_should_remain_R] - last_info_R = last_info_R[env_should_remain_R] - if hidden_state_RH is not None: - last_hidden_state_RH = last_hidden_state_RH[ - env_should_remain_R] # type: ignore[index] + # Set the hidden state to zero or None for the envs that reached done + # TODO: does it have to be so complicated? 
We should have a single clear type for hidden_state instead of
+                    # this complex logic
+                    self._reset_hidden_state_based_on_type(
+                        ready_env_ids_R,
+                        self._pre_collect_hidden_state_RH,
+                    )  # just needed to infer the type
+
+                    # preparing for the next iteration
+                    last_obs_RO = obs_reset_RO
+                    last_info_R = info_reset_R
+                    last_hidden_state_RH = (
+                        self._pre_collect_hidden_state_RH
+                    )  # todo what is actually correct, reset_env
 
             if (n_step and step_count >= n_step) or (
-                n_episode and num_collected_episodes >= n_episode
+                n_episode and num_collected_episodes >= n_episode
             ):
                 break
@@ -1314,7 +1306,7 @@ def collect(
         gym_reset_kwargs = gym_reset_kwargs or {}
 
         if reset_before_collect:
-            # first we need to step all envs to be able to interact with them
+            # First we need to step all envs to be able to interact with them
             if self.env.waiting_id:
                 self.env.step(None, id=self.env.waiting_id)
             self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)

From ec75df3250ec3c71517d48187e6cdb59b6b68526 Mon Sep 17 00:00:00 2001
From: Robert Mueller <2robert.mueller@gmail.com>
Date: Tue, 23 Apr 2024 14:21:08 +0200
Subject: [PATCH 8/9] Refactor pointer grammar in buffer and buffer/manager

---
 tianshou/data/buffer/base.py    | 40 ++++++++++++++++++---------------
 tianshou/data/buffer/manager.py | 35 ++++++++++++++++------------
 2 files changed, 41 insertions(+), 34 deletions(-)

diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index 0519d4353..e4d258380 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -41,6 +41,9 @@ class ReplayBuffer:
         "info",
         "policy",
     )
+
+    _REQUIRED_KEYS = frozenset({"obs", "act", "rew", "terminated", "truncated", "done"})
+
     _INPUT_KEYS = (
         "obs",
         "act",
@@ -181,7 +184,7 @@ def prev(self, index: int | np.ndarray) -> np.ndarray:
         return (index + end_flag) % self._size
 
     def next(self, index: int | np.ndarray) -> np.ndarray:
-        """Return the index of next transition.
+        """Return the index of the next transition.
 
         The index won't be modified if it is the end of an episode.
         """
@@ -189,9 +192,9 @@ def next(self, index: int | np.ndarray) -> np.ndarray:
         return (index + (1 - end_flag)) % self._size
 
     def update(self, buffer: "ReplayBuffer") -> np.ndarray:
-        """Move the data from the given buffer to current buffer.
+        """Move the data from the given buffer to the current buffer.
 
-        Return the updated indices. If update fails, return an empty array.
+        Return the updated indices. If the update fails, return an empty array.
         """
         if len(buffer) == 0 or self.maxsize == 0:
             return np.array([], int)
@@ -212,17 +215,17 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
         self._meta[to_indices] = buffer._meta[from_indices]
         return to_indices
 
-    def _add_index(
+    def _update_buffer_state_after_adding_batch(
         self,
         rew: float | np.ndarray,
         done: bool,
     ) -> tuple[int, float | np.ndarray, int, int]:
         """Maintain the buffer's state after adding one data batch.
 
-        Return (index_to_be_modified, episode_reward, episode_length,
+        Return (index_to_add_at, episode_reward, episode_length,
        episode_start_index).
""" - self.last_index[0] = ptr = self._index + self.last_index[0] = index_to_add_at = self._index self._size = min(self._size + 1, self.maxsize) self._index = (self._index + 1) % self.maxsize @@ -230,17 +233,17 @@ def _add_index( self._ep_len += 1 if done: - result = ptr, self._ep_rew, self._ep_len, self._ep_idx + result = index_to_add_at, self._ep_rew, self._ep_len, self._ep_idx self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index return result - return ptr, self._ep_rew * 0.0, 0, self._ep_idx + return index_to_add_at, self._ep_rew * 0.0, 0, self._ep_idx def add( self, batch: RolloutBatchProtocol, buffer_ids: np.ndarray | list[int] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Add a batch of data into replay buffer. + """Add a batch of data into the replay buffer. :param batch: the input data batch. "obs", "act", "rew", "terminated", "truncated" are required keys. @@ -257,9 +260,10 @@ def add( new_batch.__dict__[key] = batch[key] batch = new_batch batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) - assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset( - batch.keys(), - ) # important to do after preprocess batch + if missing_keys := self._REQUIRED_KEYS.difference(batch.keys()): + raise RuntimeError( + f"The input batch you try to add is missing the keys {missing_keys}.", + ) # important to do after batch preprocessing stacked_batch = buffer_ids is not None if stacked_batch: assert len(batch) == 1 @@ -269,14 +273,14 @@ def add( batch.pop("obs_next", None) elif self._save_only_last_obs: batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] - # get ptr + # get ep_add_at_idx if stacked_batch: rew, done = batch.rew[0], batch.done[0] else: rew, done = batch.rew, batch.done - ptr, ep_rew, ep_len, ep_idx = (np.array([x]) for x in self._add_index(rew, done)) + ep_add_at_idx, ep_rew, ep_len, ep_start_idx = (np.array([x]) for x in self._update_buffer_state_after_adding_batch(rew, done)) try: - self._meta[ptr] = batch + self._meta[ep_add_at_idx] = batch except ValueError: stack = not stacked_batch batch.rew = batch.rew.astype(float) @@ -287,8 +291,8 @@ def add( self._meta = create_value(batch, self.maxsize, stack) # type: ignore else: # dynamic key pops up in batch alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) - self._meta[ptr] = batch - return ptr, ep_rew, ep_len, ep_idx + self._meta[ep_add_at_idx] = batch + return ep_add_at_idx, ep_rew, ep_len, ep_start_idx def sample_indices(self, batch_size: int | None) -> np.ndarray: """Get a random sample of index with size = batch_size. @@ -349,7 +353,7 @@ def get( stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``. :param index: the index for getting stacked data. - :param str key: the key to get, should be one of the reserved_keys. + :param str key: the key to get. Should be one of the reserved_keys. :param default_value: if the given key's data is not found and default_value is set, return this default_value. :param stack_num: Set to self.stack_num if set to None. diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index fbff4a4af..73e45adf9 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -15,9 +15,9 @@ class ReplayBufferManager(ReplayBuffer): These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory. - :param buffer_list: a list of ReplayBuffer needed to be handled. 
+ :param buffer_list: a list of ReplayBuffer objects needed to be handled. - .. seealso:: + .. see also:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ @@ -118,8 +118,8 @@ def add( ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. - Each of the data's length (first dimension) must equal to the length of - buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. + Each of the data's lengths (first dimension) must be equal to the length of + buffer_ids. By default, buffer_ids is [0, 1, ..., buffer_num - 1]. Return (current_index, episode_reward, episode_length, episode_start_index). If the episode is not finished, the return value of episode_length and @@ -131,7 +131,10 @@ def add( new_batch.__dict__[key] = batch[key] batch = new_batch batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) - assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys()) + if missing_keys := self._REQUIRED_KEYS.difference(batch.keys()): + raise RuntimeError( + f"The input batch you try to add is missing the keys {missing_keys}.", + ) if self._save_only_last_obs: batch.obs = batch.obs[:, -1] if not self._save_obs_next: @@ -141,21 +144,21 @@ def add( # get index if buffer_ids is None: buffer_ids = np.arange(self.buffer_num) - ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], [] + ep_add_at_idxs, ep_lens, ep_rews, ep_start_idxs = [], [], [], [] for batch_idx, buffer_id in enumerate(buffer_ids): - ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index( + ep_add_at_idx, ep_rew, ep_len, ep_start_idx = self.buffers[buffer_id]._update_buffer_state_after_adding_batch( batch.rew[batch_idx], batch.done[batch_idx], ) - ptrs.append(ptr + self._offset[buffer_id]) + ep_add_at_idxs.append(ep_add_at_idx + self._offset[buffer_id]) ep_lens.append(ep_len) ep_rews.append(ep_rew) - ep_idxs.append(ep_idx + self._offset[buffer_id]) - self.last_index[buffer_id] = ptr + self._offset[buffer_id] + ep_start_idxs.append(ep_start_idx + self._offset[buffer_id]) + self.last_index[buffer_id] = ep_add_at_idx + self._offset[buffer_id] self._lengths[buffer_id] = len(self.buffers[buffer_id]) - ptrs = np.array(ptrs) + ep_add_at_idxs = np.array(ep_add_at_idxs) try: - self._meta[ptrs] = batch + self._meta[ep_add_at_idxs] = batch except ValueError: batch.rew = batch.rew.astype(float) batch.done = batch.done.astype(bool) @@ -166,8 +169,8 @@ def add( else: # dynamic key pops up in batch alloc_by_keys_diff(self._meta, batch, self.maxsize, False) self._set_batch_for_children() - self._meta[ptrs] = batch - return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs) + self._meta[ep_add_at_idxs] = batch + return ep_add_at_idxs, np.array(ep_rews), np.array(ep_lens), np.array(ep_start_idxs) def sample_indices(self, batch_size: int | None) -> np.ndarray: # TODO: simplify this code @@ -214,7 +217,7 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage :param buffer_list: a list of PrioritizedReplayBuffer needed to be handled. - .. seealso:: + .. see also:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ @@ -235,7 +238,7 @@ class HERReplayBufferManager(ReplayBufferManager): :param buffer_list: a list of HERReplayBuffer needed to be handled. - .. seealso:: + .. see also:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" From 7dcf7c335856eeeaa7f451394779217e2a91dbc8 Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Tue, 23 Apr 2024 14:21:43 +0200 Subject: [PATCH 9/9] Refactor pointer grammer in buffer and buffer/manager --- tianshou/data/buffer/base.py | 4 +++- tianshou/data/buffer/manager.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index e4d258380..07b7ecea4 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -278,7 +278,9 @@ def add( rew, done = batch.rew[0], batch.done[0] else: rew, done = batch.rew, batch.done - ep_add_at_idx, ep_rew, ep_len, ep_start_idx = (np.array([x]) for x in self._update_buffer_state_after_adding_batch(rew, done)) + ep_add_at_idx, ep_rew, ep_len, ep_start_idx = ( + np.array([x]) for x in self._update_buffer_state_after_adding_batch(rew, done) + ) try: self._meta[ep_add_at_idx] = batch except ValueError: diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index 73e45adf9..3099a2451 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -146,7 +146,9 @@ def add( buffer_ids = np.arange(self.buffer_num) ep_add_at_idxs, ep_lens, ep_rews, ep_start_idxs = [], [], [], [] for batch_idx, buffer_id in enumerate(buffer_ids): - ep_add_at_idx, ep_rew, ep_len, ep_start_idx = self.buffers[buffer_id]._update_buffer_state_after_adding_batch( + ep_add_at_idx, ep_rew, ep_len, ep_start_idx = self.buffers[ + buffer_id + ]._update_buffer_state_after_adding_batch( batch.rew[batch_idx], batch.done[batch_idx], )