diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6bc1703f6..2361b0009 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -12,6 +12,10 @@ ReplayBuffer, VectorReplayBuffer, ) +from tianshou.data.collector import ( + Collector_Equal_Num_Episodes_Per_Env, + Collector_First_K_Episodes, +) from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy @@ -201,7 +205,7 @@ def test_collector() -> None: Collector(policy, dummy_venv_4_envs, ReplayBuffer(10)) with pytest.raises(TypeError): Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) - with pytest.raises(TypeError): + with pytest.raises(ValueError): c_dummy_venv_4_envs.collect() # test NXEnv @@ -213,6 +217,38 @@ def test_collector() -> None: assert c_suproc_new.buffer.obs.dtype == object +def test_sample_first_k_episodes_collector() -> None: + env_lens = [1, 8, 9, 10] + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in env_lens] + venv = SubprocVectorEnv(env_fns) + policy = MaxActionPolicy() + c1 = Collector_First_K_Episodes(policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4)) + c1.reset() + res = c1.collect(n_episode=12) + assert np.array_equal(res.lens, np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9, 1, 10])) + + +def test_Collector_Equal_Num_Episodes_Per_Env() -> None: + env_lens = [1, 8, 9, 10] + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in env_lens] + venv = SubprocVectorEnv(env_fns) + policy = MaxActionPolicy() + c1 = Collector_Equal_Num_Episodes_Per_Env( + policy, + venv, + VectorReplayBuffer(total_size=100, buffer_num=4), + ) + c1.reset() + with pytest.raises(ValueError) as excinfo: + c1.collect(n_episode=9) + assert ( + str(excinfo.value) + == "n_episode has to be a multiple of the number of envs, but got n_episode=9, self.env_num=4." + ) + res = c1.collect(n_episode=12) + assert np.array_equal(res.lens, np.array([1, 8, 9, 10, 1, 8, 9, 10, 1, 8, 9, 10])) + + @pytest.fixture() def get_AsyncCollector(): env_lens = [2, 3, 4, 5] diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 53f9bd8eb..07b7ecea4 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -21,16 +21,16 @@ class ReplayBuffer: :doc:`/01_tutorials/01_concepts`. :param size: the maximum size of replay buffer. - :param stack_num: the frame-stack sampling argument, should be greater than or - equal to 1. Default to 1 (no stacking). - :param ignore_obs_next: whether to not store obs_next. Default to False. + :param stack_num: the frame-stack sampling argument. It should be greater than or + equal to 1 (no stacking). + :param ignore_obs_next: whether to not store obs_next. :param save_only_last_obs: only save the last obs/obs_next when it has a shape - of (timestep, ...) because of temporal stacking. Default to False. + of (timestep, ...) because of temporal stacking. :param sample_avail: the parameter indicating sampling only available index - when using frame-stack sampling method. Default to False. + when using frame-stack sampling method. 
""" - _reserved_keys = ( + _RESERVED_KEYS = ( "obs", "act", "rew", @@ -41,7 +41,10 @@ class ReplayBuffer: "info", "policy", ) - _input_keys = ( + + _REQUIRED_KEYS = frozenset({"obs", "act", "rew", "terminated", "truncated", "done"}) + + _INPUT_KEYS = ( "obs", "act", "rew", @@ -104,7 +107,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" - assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned" + assert key not in self._RESERVED_KEYS, f"key '{key}' is reserved and cannot be assigned" super().__setattr__(key, value) def save_hdf5(self, path: str, compression: str | None = None) -> None: @@ -162,7 +165,7 @@ def reset(self, keep_statistics: bool = False) -> None: def set_batch(self, batch: RolloutBatchProtocol) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" assert len(batch) == self.maxsize and set(batch.keys()).issubset( - self._reserved_keys, + self._RESERVED_KEYS, ), "Input batch doesn't meet ReplayBuffer's data form requirement." self._meta = batch @@ -181,7 +184,7 @@ def prev(self, index: int | np.ndarray) -> np.ndarray: return (index + end_flag) % self._size def next(self, index: int | np.ndarray) -> np.ndarray: - """Return the index of next transition. + """Return the index of the next transition. The index won't be modified if it is the end of an episode. """ @@ -189,9 +192,9 @@ def next(self, index: int | np.ndarray) -> np.ndarray: return (index + (1 - end_flag)) % self._size def update(self, buffer: "ReplayBuffer") -> np.ndarray: - """Move the data from the given buffer to current buffer. + """Move the data from the given buffer to the current buffer. - Return the updated indices. If update fails, return an empty array. + Return the updated indices. If the update fails, return an empty array. """ if len(buffer) == 0 or self.maxsize == 0: return np.array([], int) @@ -212,17 +215,17 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._meta[to_indices] = buffer._meta[from_indices] return to_indices - def _add_index( + def _update_buffer_state_after_adding_batch( self, rew: float | np.ndarray, done: bool, ) -> tuple[int, float | np.ndarray, int, int]: """Maintain the buffer's state after adding one data batch. - Return (index_to_be_modified, episode_reward, episode_length, + Return (index_to_add_at, episode_reward, episode_length, episode_start_index). """ - self.last_index[0] = ptr = self._index + self.last_index[0] = index_to_add_at = self._index self._size = min(self._size + 1, self.maxsize) self._index = (self._index + 1) % self.maxsize @@ -230,17 +233,17 @@ def _add_index( self._ep_len += 1 if done: - result = ptr, self._ep_rew, self._ep_len, self._ep_idx + result = index_to_add_at, self._ep_rew, self._ep_len, self._ep_idx self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index return result - return ptr, self._ep_rew * 0.0, 0, self._ep_idx + return index_to_add_at, self._ep_rew * 0.0, 0, self._ep_idx def add( self, batch: RolloutBatchProtocol, buffer_ids: np.ndarray | list[int] | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Add a batch of data into replay buffer. + """Add a batch of data into the replay buffer. :param batch: the input data batch. "obs", "act", "rew", "terminated", "truncated" are required keys. 
@@ -257,9 +260,10 @@ def add(
             new_batch.__dict__[key] = batch[key]
         batch = new_batch
         batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
-        assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(
-            batch.keys(),
-        )  # important to do after preprocess batch
+        if missing_keys := self._REQUIRED_KEYS.difference(batch.keys()):
+            raise RuntimeError(
+                f"The input batch you are trying to add is missing the keys {missing_keys}.",
+            )  # important to do after batch preprocessing
         stacked_batch = buffer_ids is not None
         if stacked_batch:
             assert len(batch) == 1
@@ -269,14 +273,16 @@ def add(
             batch.pop("obs_next", None)
         elif self._save_only_last_obs:
             batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1]
-        # get ptr
+        # get ep_add_at_idx
         if stacked_batch:
             rew, done = batch.rew[0], batch.done[0]
         else:
             rew, done = batch.rew, batch.done
-        ptr, ep_rew, ep_len, ep_idx = (np.array([x]) for x in self._add_index(rew, done))
+        ep_add_at_idx, ep_rew, ep_len, ep_start_idx = (
+            np.array([x]) for x in self._update_buffer_state_after_adding_batch(rew, done)
+        )
         try:
-            self._meta[ptr] = batch
+            self._meta[ep_add_at_idx] = batch
         except ValueError:
             stack = not stacked_batch
             batch.rew = batch.rew.astype(float)
@@ -287,8 +293,8 @@ def add(
                 self._meta = create_value(batch, self.maxsize, stack)  # type: ignore
             else:  # dynamic key pops up in batch
                 alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
-            self._meta[ptr] = batch
-        return ptr, ep_rew, ep_len, ep_idx
+            self._meta[ep_add_at_idx] = batch
+        return ep_add_at_idx, ep_rew, ep_len, ep_start_idx

     def sample_indices(self, batch_size: int | None) -> np.ndarray:
         """Get a random sample of index with size = batch_size.
@@ -349,10 +355,10 @@ def get(
         stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``.

         :param index: the index for getting stacked data.
-        :param str key: the key to get, should be one of the reserved_keys.
+        :param str key: the key to get. Should be one of the reserved_keys.
         :param default_value: if the given key's data is not found and
             default_value is set, return this default_value.
-        :param stack_num: Default to self.stack_num.
+        :param stack_num: the stack num to use; defaults to self.stack_num if None.
         """
         if key not in self._meta and default_value is not None:
             return default_value
@@ -415,6 +421,6 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBat
             "policy": self.get(indices, "policy", Batch()),
         }
         for key in self._meta.__dict__:
-            if key not in self._input_keys:
+            if key not in self._INPUT_KEYS:
                 batch_dict[key] = self._meta[key][indices]
         return cast(RolloutBatchProtocol, Batch(batch_dict))
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py
index a495b0ada..3099a2451 100644
--- a/tianshou/data/buffer/manager.py
+++ b/tianshou/data/buffer/manager.py
@@ -15,9 +15,9 @@ class ReplayBufferManager(ReplayBuffer):
     These replay buffers have contiguous memory layout, and the storage space each
     buffer has is a shallow copy of the topmost memory.

-    :param buffer_list: a list of ReplayBuffer needed to be handled.
+    :param buffer_list: a list of ReplayBuffer objects needed to be handled.

     .. seealso::

         Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
     """
@@ -118,8 +118,8 @@ def add(
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         """Add a batch of data into ReplayBufferManager.

-        Each of the data's length (first dimension) must equal to the length of
-        buffer_ids.
By default buffer_ids is [0, 1, ..., buffer_num - 1].
+        Each of the data's lengths (first dimension) must be equal to the length of
+        buffer_ids. By default, buffer_ids is [0, 1, ..., buffer_num - 1].

         Return (current_index, episode_reward, episode_length, episode_start_index). If
         the episode is not finished, the return value of episode_length and
@@ -127,11 +127,14 @@ def add(
         """
         # preprocess batch
         new_batch = Batch()
-        for key in set(self._reserved_keys).intersection(batch.keys()):
+        for key in set(self._RESERVED_KEYS).intersection(batch.keys()):
             new_batch.__dict__[key] = batch[key]
         batch = new_batch
         batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
-        assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys())
+        if missing_keys := self._REQUIRED_KEYS.difference(batch.keys()):
+            raise RuntimeError(
+                f"The input batch you are trying to add is missing the keys {missing_keys}.",
+            )
         if self._save_only_last_obs:
             batch.obs = batch.obs[:, -1]
         if not self._save_obs_next:
@@ -141,21 +144,23 @@ def add(
         # get index
         if buffer_ids is None:
             buffer_ids = np.arange(self.buffer_num)
-        ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], []
+        ep_add_at_idxs, ep_lens, ep_rews, ep_start_idxs = [], [], [], []
         for batch_idx, buffer_id in enumerate(buffer_ids):
-            ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index(
+            ep_add_at_idx, ep_rew, ep_len, ep_start_idx = self.buffers[
+                buffer_id
+            ]._update_buffer_state_after_adding_batch(
                 batch.rew[batch_idx],
                 batch.done[batch_idx],
             )
-            ptrs.append(ptr + self._offset[buffer_id])
+            ep_add_at_idxs.append(ep_add_at_idx + self._offset[buffer_id])
             ep_lens.append(ep_len)
             ep_rews.append(ep_rew)
-            ep_idxs.append(ep_idx + self._offset[buffer_id])
-            self.last_index[buffer_id] = ptr + self._offset[buffer_id]
+            ep_start_idxs.append(ep_start_idx + self._offset[buffer_id])
+            self.last_index[buffer_id] = ep_add_at_idx + self._offset[buffer_id]
             self._lengths[buffer_id] = len(self.buffers[buffer_id])
-        ptrs = np.array(ptrs)
+        ep_add_at_idxs = np.array(ep_add_at_idxs)
         try:
-            self._meta[ptrs] = batch
+            self._meta[ep_add_at_idxs] = batch
         except ValueError:
             batch.rew = batch.rew.astype(float)
             batch.done = batch.done.astype(bool)
@@ -166,8 +171,8 @@ def add(
             else:  # dynamic key pops up in batch
                 alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
                 self._set_batch_for_children()
-            self._meta[ptrs] = batch
-        return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
+            self._meta[ep_add_at_idxs] = batch
+        return ep_add_at_idxs, np.array(ep_rews), np.array(ep_lens), np.array(ep_start_idxs)

     def sample_indices(self, batch_size: int | None) -> np.ndarray:
         # TODO: simplify this code
@@ -214,7 +219,7 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage

     :param buffer_list: a list of PrioritizedReplayBuffer needed to be handled.

     .. seealso::

         Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
     """
@@ -235,7 +240,7 @@ class HERReplayBufferManager(ReplayBufferManager):

     :param buffer_list: a list of HERReplayBuffer needed to be handled.

     .. seealso::

         Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
""" diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 751fedfb2..edff72cfc 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,8 +1,9 @@ import time import warnings +from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass -from typing import Any, Self, TypeVar, cast +from typing import Any, Generic, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -11,7 +12,6 @@ from tianshou.data import ( Batch, CachedReplayBuffer, - PrioritizedReplayBuffer, ReplayBuffer, ReplayBufferManager, SequenceSummaryStats, @@ -122,7 +122,37 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: return result_batch_parent.info -class Collector: +_TBuffer = TypeVar("_TBuffer", bound=ReplayBuffer) + + +class CollectorBase(ABC, Generic[_TBuffer]): + @abstractmethod + def get_buffer(self) -> _TBuffer: + pass + + @abstractmethod + def reset(self) -> None: + pass + + @abstractmethod + def close(self) -> None: + pass + + @abstractmethod + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + pass + + +class _CollectorStump(ABC, Generic[_TBuffer]): """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. @@ -138,23 +168,22 @@ class Collector: .. note:: - Please make sure the given environment has a time limitation if using n_episode - collect option. + Please make sure the given environment has a time limitation if using n_episode + collect option. .. note:: - In past versions of Tianshou, the replay buffer passed to `__init__` - was automatically reset. This is not done in the current implementation. + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. """ def __init__( self, policy: BasePolicy, env: gym.Env | BaseVectorEnv, - buffer: ReplayBuffer | None = None, + buffer: _TBuffer | None = None, exploration_noise: bool = False, ) -> None: - super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy @@ -186,7 +215,8 @@ def is_closed(self) -> bool: """Return True if the collector is closed.""" return self._is_closed - def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: + # TODO: remove + def _assign_buffer(self, buffer: _TBuffer | None) -> _TBuffer: """Check if the buffer matches the constraint.""" if buffer is None: buffer = VectorReplayBuffer(self.env_num, self.env_num) @@ -197,40 +227,12 @@ def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: else: # ReplayBuffer or PrioritizedReplayBuffer assert buffer.maxsize > 0 if self.env_num > 1: - if isinstance(buffer, ReplayBuffer): - buffer_type = "ReplayBuffer" - vector_type = "VectorReplayBuffer" - if isinstance(buffer, PrioritizedReplayBuffer): - buffer_type = "PrioritizedReplayBuffer" - vector_type = "PrioritizedVectorReplayBuffer" raise TypeError( - f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) 
to collect "
-                        f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
-                        f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.",
+                    f"Cannot use {buffer.__class__.__name__}(size={buffer.maxsize}, ...) to collect "
+                    f"{self.env_num} envs, please use a corresponding vectorized buffer instead.",
                 )
         return buffer

-    def reset(
-        self,
-        reset_buffer: bool = True,
-        reset_stats: bool = True,
-        gym_reset_kwargs: dict[str, Any] | None = None,
-    ) -> None:
-        """Reset the environment, statistics, and data needed to start the collection.
-
-        :param reset_buffer: if true, reset the replay buffer attached
-            to the collector.
-        :param reset_stats: if true, reset the statistics attached to the collector.
-        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
-            reset function. Defaults to None (extra keyword arguments)
-        """
-        self.reset_env(gym_reset_kwargs=gym_reset_kwargs)
-        if reset_buffer:
-            self.reset_buffer()
-        if reset_stats:
-            self.reset_stat()
-        self._is_closed = False
-
     def reset_stat(self) -> None:
         """Reset the statistic variables."""
         self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
@@ -308,7 +310,122 @@ def _compute_action_policy_hidden(
         )
         return act_RA, act_normalized_RA, policy_R, hidden_state_RH

-    # TODO: reduce complexity, remove the noqa
+    @staticmethod
+    def _reset_hidden_state_based_on_type(
+        env_ind_local_D: np.ndarray,
+        last_hidden_state_RH: np.ndarray
+        | torch.Tensor
+        | Batch
+        | None,  # todo care only about the type not the data in it
+    ) -> None:
+        if isinstance(last_hidden_state_RH, torch.Tensor):
+            last_hidden_state_RH[env_ind_local_D].zero_()  # type: ignore[index]
+        elif isinstance(last_hidden_state_RH, np.ndarray):
+            last_hidden_state_RH[env_ind_local_D] = (
+                None if last_hidden_state_RH.dtype == object else 0
+            )
+        elif isinstance(last_hidden_state_RH, Batch):
+            last_hidden_state_RH.empty_(env_ind_local_D)
+
+    def reset(
+        self,
+        reset_buffer: bool = True,
+        reset_stats: bool = True,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> None:
+        """Reset the environment, statistics, and data needed to start the collection.
+
+        :param reset_buffer: if true, reset the replay buffer attached
+            to the collector.
+        :param reset_stats: if true, reset the statistics attached to the collector.
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Defaults to None (extra keyword arguments)
+        """
+        self.reset_env(gym_reset_kwargs=gym_reset_kwargs)
+        if reset_buffer:
+            self.reset_buffer()
+        if reset_stats:
+            self.reset_stat()
+        self._is_closed = False
+
+    def _validate_collect_input(
+        self,
+        n_episode: int | None,
+        n_step: int | None,
+    ) -> None:
+        """Check that exactly one of n_step and n_episode is specified and that the
+        given value is valid; a warning is emitted if it does not fit the number of envs.
+        """
+        if n_step is not None and n_episode is not None:
+            raise ValueError(
+                f"Only one of n_step or n_episode is allowed in Collector."
+                f"collect, got {n_step=}, {n_episode=}.",
+            )
+
+        if n_step is not None:
+            if n_step < 1:
+                raise ValueError(
+                    f"n_step should be an integer larger than 0, but got {n_step=}.",
+                )
+
+            if n_step % self.env_num:
+                warnings.warn(
+                    f"{n_step=} is not a multiple of ({self.env_num=}).
" + "This may cause extra transitions to be collected into the buffer.", + ) + + elif n_episode is not None: + if n_episode < 1: + raise ValueError( + f"{n_episode=} should be an integer larger than 0.", + ) + if n_episode < self.env_num: + warnings.warn( + f"{n_episode=} should be larger than or equal to {self.env_num=} " + f"(otherwise you will get idle workers and won't collect at" + f"least one trajectory in each env).", + ) + + else: + raise ValueError( + f"At least one of {n_step=} and {n_episode=} should be specified as int larger than 0.", + ) + + +class Collector(_CollectorStump[_TBuffer], Generic[_TBuffer]): + """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. + + .. note:: + + Please make sure the given environment has a time limitation if using n_episode + collect option. + + .. note:: + + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. + """ + + def __init__( + self, + policy: BasePolicy, + env: gym.Env | BaseVectorEnv, + buffer: _TBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + super().__init__(policy, env, buffer, exploration_noise) + def collect( self, n_step: int | None = None, @@ -359,36 +476,17 @@ def collect( # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. - use_grad = not no_grad - gym_reset_kwargs = gym_reset_kwargs or {} - # Input validation assert not self.env.is_async, "Please use AsyncCollector if using async venv." - if n_step is not None: - assert n_episode is None, ( - f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got {n_step=}, {n_episode=}." 
- ) - assert n_step > 0 - if n_step % self.env_num != 0: - warnings.warn( - f"{n_step=} is not a multiple of ({self.env_num=}), " - "which may cause extra transitions being collected into the buffer.", - ) - ready_env_ids_R = np.arange(self.env_num) - elif n_episode is not None: - assert n_episode > 0 - if self.env_num > n_episode: - warnings.warn( - f"{n_episode=} should be larger than {self.env_num=} to " - f"collect at least one trajectory in each environment.", - ) + self._validate_collect_input(n_episode, n_step) + + if n_episode: ready_env_ids_R = np.arange(min(self.env_num, n_episode)) - else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) + else: # n_step case + ready_env_ids_R = np.arange(self.env_num) + + use_grad = not no_grad + gym_reset_kwargs = gym_reset_kwargs or {} start_time = time.time() @@ -492,30 +590,33 @@ def collect( # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration # Resetting envs that reached done, or removing some of them from the collection if needed (see below) - if num_episodes_done_this_iter > 0: + if np.any(done_R): # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays # D - number of envs that reached done in the rollout above - env_ind_local_D = np.where(done_R)[0] - env_ind_global_D = ready_env_ids_R[env_ind_local_D] - episode_lens.extend(ep_len_R[env_ind_local_D]) - episode_returns.extend(ep_rew_R[env_ind_local_D]) - episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + local_ready_env_ids_done_D = np.where(done_R)[0] + global_ready_env_ids_done_D = ready_env_ids_R[local_ready_env_ids_done_D] + episode_lens.extend(ep_len_R[local_ready_env_ids_done_D]) + episode_returns.extend(ep_rew_R[local_ready_env_ids_done_D]) + episode_start_indices.extend(ep_idx_R[local_ready_env_ids_done_D]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. obs_reset_DO, info_reset_D = self.env.reset( - env_id=env_ind_global_D, + env_id=global_ready_env_ids_done_D, **gym_reset_kwargs, ) # Set the hidden state to zero or None for the envs that reached done # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of # this complex logic - self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + self._reset_hidden_state_based_on_type( + local_ready_env_ids_done_D, + last_hidden_state_RH, + ) # preparing for the next iteration - last_obs_RO[env_ind_local_D] = obs_reset_DO - last_info_R[env_ind_local_D] = info_reset_D + last_obs_RO[local_ready_env_ids_done_D] = obs_reset_DO + last_info_R[local_ready_env_ids_done_D] = info_reset_D # Handling the case when we have more ready envs than desired and are not done yet # @@ -541,7 +642,7 @@ def collect( # step and we still need to collect the remaining episodes to reach the breaking condition. 
# creating the mask - env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num] + env_to_be_ignored_ind_local_S = local_ready_env_ids_done_D[:surplus_env_num] env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) env_should_remain_R[env_to_be_ignored_ind_local_S] = False # stripping the "idle" indices, shortening the relevant quantities from R to R-S @@ -580,84 +681,40 @@ def collect( collect_speed=step_count / collect_time, ) - def _reset_hidden_state_based_on_type( - self, - env_ind_local_D: np.ndarray, - last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, - ) -> None: - if isinstance(last_hidden_state_RH, torch.Tensor): - last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] - elif isinstance(last_hidden_state_RH, np.ndarray): - last_hidden_state_RH[env_ind_local_D] = ( - None if last_hidden_state_RH.dtype == object else 0 - ) - elif isinstance(last_hidden_state_RH, Batch): - last_hidden_state_RH.empty_(env_ind_local_D) - # todo is this inplace magic and just working? +class Collector_First_K_Episodes(_CollectorStump[_TBuffer], Generic[_TBuffer]): + """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. -class AsyncCollector(Collector): - """Async Collector handles async vector environment. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. - The arguments are exactly the same as :class:`~tianshou.data.Collector`, please - refer to :class:`~tianshou.data.Collector` for more detailed explanation. + .. note:: + + Please make sure the given environment has a time limitation if using n_episode + collect option. + + .. note:: + + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. """ def __init__( self, policy: BasePolicy, - env: BaseVectorEnv, - buffer: ReplayBuffer | None = None, + env: gym.Env | BaseVectorEnv, + buffer: _TBuffer | None = None, exploration_noise: bool = False, ) -> None: - # assert env.is_async - warnings.warn("Using async setting may collect extra transitions into buffer.") - super().__init__( - policy, - env, - buffer, - exploration_noise, - ) - # E denotes the number of parallel environments: self.env_num - # At init, E=R but during collection R <= E - # Keep in sync with reset! 
- self._ready_env_ids_R: np.ndarray = np.arange(self.env_num) - self._current_obs_in_all_envs_EO: np.ndarray | None = copy(self._pre_collect_obs_RO) - self._current_info_in_all_envs_E: np.ndarray | None = copy(self._pre_collect_info_R) - self._current_hidden_state_in_all_envs_EH: np.ndarray | torch.Tensor | Batch | None = copy( - self._pre_collect_hidden_state_RH, - ) - self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) - self._current_policy_in_all_envs_E: Batch | None = None - - def reset( - self, - reset_buffer: bool = True, - reset_stats: bool = True, - gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: - """Reset the environment, statistics, and data needed to start the collection. - - :param reset_buffer: if true, reset the replay buffer attached - to the collector. - :param reset_stats: if true, reset the statistics attached to the collector. - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Defaults to None (extra keyword arguments) - """ - # This sets the _pre_collect attrs - super().reset( - reset_buffer=reset_buffer, - reset_stats=reset_stats, - gym_reset_kwargs=gym_reset_kwargs, - ) - # Keep in sync with init! - self._ready_env_ids_R = np.arange(self.env_num) - # E denotes the number of parallel environments self.env_num - self._current_obs_in_all_envs_EO = copy(self._pre_collect_obs_RO) - self._current_info_in_all_envs_E = copy(self._pre_collect_info_R) - self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) - self._current_action_in_all_envs_EA = np.empty(self.env_num) - self._current_policy_in_all_envs_E = None + super().__init__(policy, env, buffer, exploration_noise) def collect( self, @@ -669,9 +726,555 @@ def collect( reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: - """Collect a specified number of steps or episodes with async env setting. + # Input validation + assert not self.env.is_async, "Please use AsyncCollector if using async venv." + if n_step: + raise ValueError( + f"First_K_Episodes collector only supports n_episode, but got {n_step=}.", + ) + assert n_episode, "n_episode should be specified." + if n_episode < 1: + raise ValueError( + f"{n_episode=} should be an integer larger than 0.", + ) + if n_episode < self.env_num: + warnings.warn( + f"{n_episode=} should be larger than or equal to {self.env_num=} " + f"(otherwise you will get idle workers and won't collect at" + f"least one trajectory in each env).", + ) - This function does not collect an exact number of transitions specified by n_step or + ready_env_ids_R = np.arange(min(self.env_num, n_episode)) + + use_grad = not no_grad + gym_reset_kwargs = gym_reset_kwargs or {} + + start_time = time.time() + + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: + raise ValueError( + "Initial obs and info should not be None. 
" + "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.", + ) + + # get the first obs to be the current obs in the n_step case as + # episodes as a new call to collect does not restart trajectories + # (which we also really don't want) + step_count = 0 + num_collected_episodes = 0 + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] + + # in case we select fewer episodes than envs, we run only some of them + last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R) + last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R) + last_hidden_state_RH = _nullable_slice( + self._pre_collect_hidden_state_RH, + ready_env_ids_R, + ) + + while True: + # todo check if we need this when using cur_rollout_batch + # if len(cur_rollout_batch) != len(ready_env_ids): + # raise RuntimeError( + # f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids" + # f"{len(ready_env_ids)}. This should not happen and could be a bug!", + # ) + # restore the state: if the last state is None, it won't store + + # get the next action + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + use_grad=use_grad, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) + + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, + ) + if isinstance(info_R, dict): # type: ignore[unreachable] + # This can happen if the env is an envpool env. Then the info returned by step is a dict + info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] + done_R = np.logical_or(terminated_R, truncated_R) + + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=last_obs_RO, + act=act_RA, + policy=policy_R, + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), + ) + + # TODO: only makes sense if render_mode is human. 
+ # Also, doubtful whether it makes sense at all for true vectorized envs + if render: + self.env.render() + if not np.isclose(render, 0): + time.sleep(render) + + # add data into the buffer + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) + + # collect statistics + num_episodes_done_this_iter = np.sum(done_R) + num_collected_episodes += num_episodes_done_this_iter + step_count += len(ready_env_ids_R) + + # preparing for the next iteration + # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(hidden_state_RH) + + # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration + # Resetting envs that reached done, or removing some of them from the collection if needed (see below) + if num_episodes_done_this_iter > 0: + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + # now we copy obs_next to obs, but since there might be + # finished episodes, we have to reset finished envs first. + + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + + # Set the hidden state to zero or None for the envs that reached done + # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of + # this complex logic + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + + # preparing for the next iteration + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D + + # Handling the case when we have more ready envs than desired and are not done yet + # + # This can only happen if we are collecting a fixed number of episodes + # If we have more ready envs than there are remaining episodes to collect, + # we will remove some of them for the next rollout + # One effect of this is the following: only envs that have completed an episode + # in the last step can ever be removed from the ready envs. + # Thus, this guarantees that each env will contribute at least one episode to the + # collected data (the buffer). This effect was previous called "avoiding bias in selecting environments" + # However, it is not at all clear whether this is actually useful or necessary. + # Additional naming convention: + # S - number of surplus envs + # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones. + # Changing R to R-S highly increases the complexity of the code. 
+            if (n_step and step_count >= n_step) or (
+                n_episode and num_collected_episodes >= n_episode
+            ):
+                break
+
+        # generate statistics
+        self.collect_step += step_count
+        self.collect_episode += num_collected_episodes
+        collect_time = max(time.time() - start_time, 1e-9)
+        self.collect_time += collect_time
+
+        if n_step:
+            # persist for future collect iterations
+            self._pre_collect_obs_RO = last_obs_RO
+            self._pre_collect_info_R = last_info_R
+            self._pre_collect_hidden_state_RH = last_hidden_state_RH
+        elif n_episode:
+            # reset envs and the _pre_collect fields
+            self.reset_env(gym_reset_kwargs)  # todo still necessary?
+
+        return CollectStats.with_autogenerated_stats(
+            returns=np.array(episode_returns),
+            lens=np.array(episode_lens),
+            n_collected_episodes=num_collected_episodes,
+            n_collected_steps=step_count,
+            collect_time=collect_time,
+            collect_speed=step_count / collect_time,
+        )
+
+
+class Collector_Equal_Num_Episodes_Per_Env(_CollectorStump[_TBuffer], Generic[_TBuffer]):
+    """Collector that collects an equal number of episodes from each of the vectorized envs.
+
+    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
+    :param env: a ``gym.Env`` environment or an instance of the
+        :class:`~tianshou.env.BaseVectorEnv` class.
+    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
+        If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
+        as the default buffer.
+    :param exploration_noise: determine whether the action needs to be modified
+        with the corresponding policy's exploration noise. If so, "policy.
+        exploration_noise(act, batch)" will be called automatically to add the
+        exploration noise into action. Default to False.
+
+    .. note::
+
+        Please make sure the given environment has a time limitation if using n_episode
+        collect option.
+
+    .. note::
+
+        In past versions of Tianshou, the replay buffer passed to `__init__`
+        was automatically reset. This is not done in the current implementation.
+
+    This is not the most efficient implementation possible: the collector collects
+    one episode in every environment, so an environment that has already reached done
+    stays idle until the "slowest" environment also reaches done. Non-idle envs are
+    referred to as ready.
+    This could be optimized by explicitly tracking how many trajectories have been
+    collected in each environment.
+    """
+
+    def __init__(
+        self,
+        policy: BasePolicy,
+        env: gym.Env | BaseVectorEnv,
+        buffer: _TBuffer | None = None,
+        exploration_noise: bool = False,
+    ) -> None:
+        super().__init__(policy, env, buffer, exploration_noise)
+
+    def collect(
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        reset_before_collect: bool = False,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> CollectStats:
+        """Collect a specified number of episodes.
+
+        Only the n_episode option is supported here. ``n_episode`` has to be a multiple
+        of the number of envs, and each env contributes exactly ``n_episode / env_num``
+        episodes, so the collected data is balanced across environments.
+
+        :param n_step: how many steps you want to collect.
+        :param n_episode: how many episodes you want to collect.
+        :param random: whether to use random policy for collecting data.
+        :param render: the sleep time between rendering consecutive frames.
+        :param no_grad: whether to retain gradient in policy.forward().
+        :param reset_before_collect: whether to reset the environment before
+            collecting data.
+            It only has an effect if n_episode is not None, i.e.
+            if one wants to collect a fixed number of episodes.
+            (The collector needs the initial obs and info to function properly.)
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Only used if reset_before_collect is True.
+
+        .. note::
+
+            One and only one collection number specification is permitted, either
+            ``n_step`` or ``n_episode``.
+
+        :return: The collected stats
+        """
+        # NAMING CONVENTION (mostly suffixes):
+        # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
+        # the corresponding env is either reset or removed from the ready envs.
+        # R - number of ready env ids. Note that this might change when envs get idle.
+        # This can only happen in n_episode case, see explanation in the corresponding block.
+        # For n_step, we always use all envs to collect the data, while for n_episode,
+        # R will be at most n_episode at the beginning, but can decrease during the collection.
+        # O - dimension(s) of observations
+        # A - dimension(s) of actions
+        # H - dimension(s) of hidden state
+        # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
+        # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
+        # Only used in n_episode case. Then, R becomes R-S.
+
+        # Input validation
+        assert not self.env.is_async, "Please use AsyncCollector if using async venv."
+        if n_step:
+            raise ValueError(
+                f"Collector_Equal_Num_Episodes_Per_Env only supports n_episode, but got {n_step=}.",
+            )
+        assert n_episode is not None, "n_episode should be specified."  # needed for mypy
+        if n_episode < 1:
+            raise ValueError(
+                f"{n_episode=} should be an integer larger than 0.",
+            )
+
+        if n_episode % self.env_num != 0:
+            raise ValueError(
+                f"n_episode has to be a multiple of the number of envs, but got {n_episode=}, {self.env_num=}.",
+            )
+        ready_env_ids_R = np.arange(self.env_num)
+
+        use_grad = not no_grad
+        gym_reset_kwargs = gym_reset_kwargs or {}
+
+        start_time = time.time()
+
+        if reset_before_collect:
+            self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)
+
+        if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
+            raise ValueError(
+                "Initial obs and info should not be None. "
+                "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.",
+            )
+
+        # get the first obs to be the current obs in the n_step case as
+        # episodes as a new call to collect does not restart trajectories
+        # (which we also really don't want)
+        step_count = 0
+        num_collected_episodes = 0
+        episode_returns: list[float] = []
+        episode_lens: list[int] = []
+        episode_start_indices: list[int] = []
+
+        # in case we select fewer episodes than envs, we run only some of them
+        last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R)
+        last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R)
+        last_hidden_state_RH = _nullable_slice(
+            self._pre_collect_hidden_state_RH,
+            ready_env_ids_R,
+        )
+
+        while True:
+            # todo check if we need this when using cur_rollout_batch
+            # if len(cur_rollout_batch) != len(ready_env_ids):
+            #     raise RuntimeError(
+            #         f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids"
+            #         f"{len(ready_env_ids)}.
This should not happen and could be a bug!", + # ) + # restore the state: if the last state is None, it won't store + + # get the next action + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + use_grad=use_grad, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) + + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, + ) + if isinstance(info_R, dict): # type: ignore[unreachable] + # This can happen if the env is an envpool env. Then the info returned by step is a dict + info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] + done_R = np.logical_or(terminated_R, truncated_R) + + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=last_obs_RO, + act=act_RA, + policy=policy_R, + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), + ) + + # TODO: only makes sense if render_mode is human. + # Also, doubtful whether it makes sense at all for true vectorized envs + if render: + self.env.render() + if not np.isclose(render, 0): + time.sleep(render) + + # add data into the buffer + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) + + # collect statistics + num_episodes_done_this_iter = np.sum(done_R) + num_collected_episodes += num_episodes_done_this_iter + step_count += len(ready_env_ids_R) + + # preparing for the next iteration + # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(hidden_state_RH) + + # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration + # Resetting envs that reached done, or removing some of them from the collection if needed (see below) + if np.any(done_R): + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + local_ready_env_ids_done_D = np.where(done_R)[0] + episode_lens.extend(ep_len_R[local_ready_env_ids_done_D]) + episode_returns.extend(ep_rew_R[local_ready_env_ids_done_D]) + episode_start_indices.extend(ep_idx_R[local_ready_env_ids_done_D]) + + env_array_ids_non_idle_not_done_R = np.where(~done_R)[0] + # new R, all envs that haven't reached done yet are remain ready + ready_env_ids_R = ready_env_ids_R[env_array_ids_non_idle_not_done_R] + last_obs_RO = last_obs_RO[env_array_ids_non_idle_not_done_R] + last_info_R = last_info_R[env_array_ids_non_idle_not_done_R] + last_hidden_state_RH = last_hidden_state_RH[env_array_ids_non_idle_not_done_R] # type: ignore[index] + + if len(ready_env_ids_R) == 0: + ready_env_ids_R = np.arange(self.env_num) # so now R == E again + + obs_reset_RO, info_reset_R = self.env.reset( + env_id=ready_env_ids_R, + **gym_reset_kwargs, + ) + + # Set the hidden state to zero or None for the envs that reached done + # TODO: does it have to be so complicated? 
We should have a single clear type for hidden_state instead of + # this complex logic + self._reset_hidden_state_based_on_type( + ready_env_ids_R, + self._pre_collect_hidden_state_RH, + ) # just needed to infer the type + + # preparing for the next iteration + last_obs_RO = obs_reset_RO + last_info_R = info_reset_R + last_hidden_state_RH = ( + self._pre_collect_hidden_state_RH + ) # todo what is actually correct, reset_env + + if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): + break + + # generate statistics + self.collect_step += step_count + self.collect_episode += num_collected_episodes + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time + + if n_episode: + # reset envs and the _pre_collect fields + self.reset_env(gym_reset_kwargs) # todo still necessary? + + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + ) + + +class AsyncCollector(_CollectorStump[_TBuffer], Generic[_TBuffer]): + """Async Collector handles async vector environment. + + The arguments are exactly the same as :class:`~tianshou.data.Collector`, please + refer to :class:`~tianshou.data.Collector` for more detailed explanation. + """ + + def __init__( + self, + policy: BasePolicy, + env: BaseVectorEnv, + buffer: _TBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + # assert env.is_async + warnings.warn("Using async setting may collect extra transitions into buffer.") + super().__init__( + policy, + env, + buffer, + exploration_noise, + ) + # E denotes the number of parallel environments: self.env_num + # At init, E=R but during collection R <= E + # Keep in sync with reset! + self._ready_env_ids_R: np.ndarray = np.arange(self.env_num) + self._current_obs_in_all_envs_EO: np.ndarray | None = copy(self._pre_collect_obs_RO) + self._current_info_in_all_envs_E: np.ndarray | None = copy(self._pre_collect_info_R) + self._current_hidden_state_in_all_envs_EH: np.ndarray | torch.Tensor | Batch | None = copy( + self._pre_collect_hidden_state_RH, + ) + self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) + self._current_policy_in_all_envs_E: Batch | None = None + + def reset( + self, + reset_buffer: bool = True, + reset_stats: bool = True, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> None: + """Reset the environment, statistics, and data needed to start the collection. + + :param reset_buffer: if true, reset the replay buffer attached + to the collector. + :param reset_stats: if true, reset the statistics attached to the collector. + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (extra keyword arguments) + """ + # This sets the _pre_collect attrs + super().reset( + reset_buffer=reset_buffer, + reset_stats=reset_stats, + gym_reset_kwargs=gym_reset_kwargs, + ) + # Keep in sync with init! 
+ self._ready_env_ids_R = np.arange(self.env_num) + # E denotes the number of parallel environments self.env_num + self._current_obs_in_all_envs_EO = copy(self._pre_collect_obs_RO) + self._current_info_in_all_envs_E = copy(self._pre_collect_info_R) + self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) + self._current_action_in_all_envs_EA = np.empty(self.env_num) + self._current_policy_in_all_envs_E = None + + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + """Collect a specified number of steps or episodes with async env setting. + + This function does not collect an exact number of transitions specified by n_step or n_episode. Instead, to support the asynchronous setting, it may collect more transitions than requested by n_step or n_episode and save them into the buffer. @@ -697,26 +1300,13 @@ def collect( :return: A dataclass object """ + self._validate_collect_input(n_episode, n_step) + use_grad = not no_grad gym_reset_kwargs = gym_reset_kwargs or {} - # collect at least n_step or n_episode - if n_step is not None: - assert n_episode is None, ( - "Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}." - ) - assert n_step > 0 - elif n_episode is not None: - assert n_episode > 0 - else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) - if reset_before_collect: - # first we need to step all envs to be able to interact with them + # First we need to step all envs to be able to interact with them if self.env.waiting_id: self.env.step(None, id=self.env.waiting_id) self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index f71a7f981..c394f4b72 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -119,8 +119,13 @@ def create_train_test_collector( save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, ) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, envs.test_envs) + train_collector: Collector[ReplayBuffer] = Collector( + policy, + train_envs, + buffer, + exploration_noise=True, + ) + test_collector: Collector[ReplayBuffer] = Collector(policy, envs.test_envs) if reset_collectors: train_collector.reset() test_collector.reset() diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 17f0550fc..b5df56bfe 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -9,7 +9,7 @@ import numpy as np import torch -from tianshou.data import Collector, InfoStats +from tianshou.data import Collector, InfoStats, VectorReplayBuffer from tianshou.env import BaseVectorEnv from tianshou.highlevel.agent import ( A2CAgentFactory, @@ -310,7 +310,7 @@ def _watch_agent( render: float, ) -> None: policy.eval() - collector = Collector(policy, env) + collector: Collector[VectorReplayBuffer] = Collector(policy, env) result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy
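
For reference, the snippet below sketches how the two new collector variants added in this patch are meant to be used; it mirrors the new tests in test/base/test_collector.py. The import locations of MoveToRightEnv and MaxActionPolicy are assumptions (they are test helpers from the tianshou test suite, not part of the public library API).

```python
import numpy as np

from tianshou.data import VectorReplayBuffer
from tianshou.data.collector import (
    Collector_Equal_Num_Episodes_Per_Env,
    Collector_First_K_Episodes,
)
from tianshou.env import SubprocVectorEnv

# Test helpers, assumed importable as below; not part of tianshou's public API.
from test.base.env import MoveToRightEnv
from test.base.test_collector import MaxActionPolicy

# Four envs whose episodes take 1, 8, 9 and 10 steps respectively.
env_lens = [1, 8, 9, 10]
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in env_lens]

policy = MaxActionPolicy()

# Collector_First_K_Episodes: stops once at least n_episode episodes are done.
# Because all ready envs step together, episodes finishing in the same iteration
# are also recorded, so more than n_episode entries can appear
# (13 episode lengths for n_episode=12 in the new test).
c_first_k = Collector_First_K_Episodes(
    policy,
    SubprocVectorEnv(env_fns),
    VectorReplayBuffer(total_size=100, buffer_num=4),
)
c_first_k.reset()
stats = c_first_k.collect(n_episode=12)
print(stats.lens)

# Collector_Equal_Num_Episodes_Per_Env: n_episode must be a multiple of the number
# of envs; each env then contributes exactly n_episode / env_num episodes
# (three episodes of each length here), otherwise a ValueError is raised.
c_equal = Collector_Equal_Num_Episodes_Per_Env(
    policy,
    SubprocVectorEnv(env_fns),
    VectorReplayBuffer(total_size=100, buffer_num=4),
)
c_equal.reset()
stats = c_equal.collect(n_episode=12)
print(stats.lens)
```

The trade-off the two variants expose: the first-K strategy can over-represent short episodes when episode lengths differ across envs, while the equal-per-env strategy keeps the sample balanced at the cost of idle workers while the slowest env finishes.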