Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/collect equal episode num in all envs #1127

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
38 changes: 37 additions & 1 deletion test/base/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
ReplayBuffer,
VectorReplayBuffer,
)
from tianshou.data.collector import (
Collector_Equal_Num_Episodes_Per_Env,
Collector_First_K_Episodes,
)
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy

Expand Down Expand Up @@ -201,7 +205,7 @@ def test_collector() -> None:
Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
with pytest.raises(TypeError):
Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
with pytest.raises(TypeError):
with pytest.raises(ValueError):
c_dummy_venv_4_envs.collect()

# test NXEnv
Expand All @@ -213,6 +217,38 @@ def test_collector() -> None:
assert c_suproc_new.buffer.obs.dtype == object


def test_sample_first_k_episodes_collector() -> None:
    """Collect episodes with Collector_First_K_Episodes over four envs of unequal lengths."""
    episode_lengths = [1, 8, 9, 10]
    # Bind the loop variable as a default argument to avoid the late-binding closure pitfall.
    factories = [lambda size=length: MoveToRightEnv(size=size, sleep=0) for length in episode_lengths]
    collector = Collector_First_K_Episodes(
        MaxActionPolicy(),
        SubprocVectorEnv(factories),
        VectorReplayBuffer(total_size=100, buffer_num=4),
    )
    collector.reset()
    stats = collector.collect(n_episode=12)
    # NOTE(review): 13 episode lengths are expected although n_episode=12 was requested --
    # confirm that the surplus episode finishing in the same step is intended here.
    expected_lens = np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9, 1, 10])
    assert np.array_equal(stats.lens, expected_lens)


def test_Collector_Equal_Num_Episodes_Per_Env() -> None:
    """Collector_Equal_Num_Episodes_Per_Env rejects n_episode values not divisible by
    the number of envs and otherwise collects the same number of episodes per env."""
    episode_lengths = [1, 8, 9, 10]
    # Bind the loop variable as a default argument to avoid the late-binding closure pitfall.
    factories = [lambda size=length: MoveToRightEnv(size=size, sleep=0) for length in episode_lengths]
    collector = Collector_Equal_Num_Episodes_Per_Env(
        MaxActionPolicy(),
        SubprocVectorEnv(factories),
        VectorReplayBuffer(total_size=100, buffer_num=4),
    )
    collector.reset()
    # 9 is not a multiple of the 4 envs, so collect() must refuse it.
    with pytest.raises(ValueError) as excinfo:
        collector.collect(n_episode=9)
    expected_msg = (
        "n_episode has to be a multiple of the number of envs, but got n_episode=9, self.env_num=4."
    )
    assert str(excinfo.value) == expected_msg
    # 12 episodes over 4 envs: three full rounds, one episode per env per round.
    stats = collector.collect(n_episode=12)
    expected_lens = np.array([1, 8, 9, 10, 1, 8, 9, 10, 1, 8, 9, 10])
    assert np.array_equal(stats.lens, expected_lens)


@pytest.fixture()
def get_AsyncCollector():
env_lens = [2, 3, 4, 5]
Expand Down
64 changes: 35 additions & 29 deletions tianshou/data/buffer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ class ReplayBuffer:
:doc:`/01_tutorials/01_concepts`.

:param size: the maximum size of replay buffer.
:param stack_num: the frame-stack sampling argument, should be greater than or
equal to 1. Default to 1 (no stacking).
:param ignore_obs_next: whether to not store obs_next. Default to False.
:param stack_num: the frame-stack sampling argument. It should be greater than or
equal to 1, where 1 means no frame-stacking.
:param ignore_obs_next: whether to not store obs_next.
:param save_only_last_obs: only save the last obs/obs_next when it has a shape
of (timestep, ...) because of temporal stacking. Default to False.
of (timestep, ...) because of temporal stacking.
:param sample_avail: the parameter indicating sampling only available index
when using frame-stack sampling method. Default to False.
when using frame-stack sampling method.
"""

_reserved_keys = (
_RESERVED_KEYS = (
"obs",
"act",
"rew",
Expand All @@ -41,7 +41,10 @@ class ReplayBuffer:
"info",
"policy",
)
_input_keys = (

_REQUIRED_KEYS = frozenset({"obs", "act", "rew", "terminated", "truncated", "done"})

_INPUT_KEYS = (
"obs",
"act",
"rew",
Expand Down Expand Up @@ -104,7 +107,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:

def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned"
assert key not in self._RESERVED_KEYS, f"key '{key}' is reserved and cannot be assigned"
super().__setattr__(key, value)

def save_hdf5(self, path: str, compression: str | None = None) -> None:
Expand Down Expand Up @@ -162,7 +165,7 @@ def reset(self, keep_statistics: bool = False) -> None:
def set_batch(self, batch: RolloutBatchProtocol) -> None:
"""Manually choose the batch you want the ReplayBuffer to manage."""
assert len(batch) == self.maxsize and set(batch.keys()).issubset(
self._reserved_keys,
self._RESERVED_KEYS,
), "Input batch doesn't meet ReplayBuffer's data form requirement."
self._meta = batch

Expand All @@ -181,17 +184,17 @@ def prev(self, index: int | np.ndarray) -> np.ndarray:
return (index + end_flag) % self._size

def next(self, index: int | np.ndarray) -> np.ndarray:
"""Return the index of next transition.
"""Return the index of the next transition.

The index won't be modified if it is the end of an episode.
"""
end_flag = self.done[index] | (index == self.last_index[0])
return (index + (1 - end_flag)) % self._size

def update(self, buffer: "ReplayBuffer") -> np.ndarray:
"""Move the data from the given buffer to current buffer.
"""Move the data from the given buffer to the current buffer.

Return the updated indices. If update fails, return an empty array.
Return the updated indices. If the update fails, return an empty array.
"""
if len(buffer) == 0 or self.maxsize == 0:
return np.array([], int)
Expand All @@ -212,35 +215,35 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
self._meta[to_indices] = buffer._meta[from_indices]
return to_indices

def _add_index(
def _update_buffer_state_after_adding_batch(
self,
rew: float | np.ndarray,
done: bool,
) -> tuple[int, float | np.ndarray, int, int]:
"""Maintain the buffer's state after adding one data batch.

Return (index_to_be_modified, episode_reward, episode_length,
Return (index_to_add_at, episode_reward, episode_length,
episode_start_index).
"""
self.last_index[0] = ptr = self._index
self.last_index[0] = index_to_add_at = self._index
self._size = min(self._size + 1, self.maxsize)
self._index = (self._index + 1) % self.maxsize

self._ep_rew += rew
self._ep_len += 1

if done:
result = ptr, self._ep_rew, self._ep_len, self._ep_idx
result = index_to_add_at, self._ep_rew, self._ep_len, self._ep_idx
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index
return result
return ptr, self._ep_rew * 0.0, 0, self._ep_idx
return index_to_add_at, self._ep_rew * 0.0, 0, self._ep_idx

def add(
self,
batch: RolloutBatchProtocol,
buffer_ids: np.ndarray | list[int] | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Add a batch of data into replay buffer.
"""Add a batch of data into the replay buffer.

:param batch: the input data batch. "obs", "act", "rew",
"terminated", "truncated" are required keys.
Expand All @@ -257,9 +260,10 @@ def add(
new_batch.__dict__[key] = batch[key]
batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(
batch.keys(),
) # important to do after preprocess batch
if missing_keys := self._REQUIRED_KEYS.difference(batch.keys()):
raise RuntimeError(
f"The input batch you try to add is missing the keys {missing_keys}.",
) # important to do after batch preprocessing
stacked_batch = buffer_ids is not None
if stacked_batch:
assert len(batch) == 1
Expand All @@ -269,14 +273,16 @@ def add(
batch.pop("obs_next", None)
elif self._save_only_last_obs:
batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1]
# get ptr
# get ep_add_at_idx
if stacked_batch:
rew, done = batch.rew[0], batch.done[0]
else:
rew, done = batch.rew, batch.done
ptr, ep_rew, ep_len, ep_idx = (np.array([x]) for x in self._add_index(rew, done))
ep_add_at_idx, ep_rew, ep_len, ep_start_idx = (
np.array([x]) for x in self._update_buffer_state_after_adding_batch(rew, done)
)
try:
self._meta[ptr] = batch
self._meta[ep_add_at_idx] = batch
except ValueError:
stack = not stacked_batch
batch.rew = batch.rew.astype(float)
Expand All @@ -287,8 +293,8 @@ def add(
self._meta = create_value(batch, self.maxsize, stack) # type: ignore
else: # dynamic key pops up in batch
alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
self._meta[ptr] = batch
return ptr, ep_rew, ep_len, ep_idx
self._meta[ep_add_at_idx] = batch
return ep_add_at_idx, ep_rew, ep_len, ep_start_idx

def sample_indices(self, batch_size: int | None) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
Expand Down Expand Up @@ -349,10 +355,10 @@ def get(
stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``.

:param index: the index for getting stacked data.
:param str key: the key to get, should be one of the reserved_keys.
:param str key: the key to get. Should be one of the reserved_keys.
:param default_value: if the given key's data is not found and default_value is
set, return this default_value.
:param stack_num: Default to self.stack_num.
:param stack_num: Set to self.stack_num if set to None.
"""
if key not in self._meta and default_value is not None:
return default_value
Expand Down Expand Up @@ -415,6 +421,6 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBat
"policy": self.get(indices, "policy", Batch()),
}
for key in self._meta.__dict__:
if key not in self._input_keys:
if key not in self._INPUT_KEYS:
batch_dict[key] = self._meta[key][indices]
return cast(RolloutBatchProtocol, Batch(batch_dict))
39 changes: 22 additions & 17 deletions tianshou/data/buffer/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class ReplayBufferManager(ReplayBuffer):
These replay buffers have contiguous memory layout, and the storage space each
buffer has is a shallow copy of the topmost memory.

:param buffer_list: a list of ReplayBuffer needed to be handled.
:param buffer_list: a list of ReplayBuffer objects to be managed.

.. seealso::
.. seealso::

Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
Expand Down Expand Up @@ -118,20 +118,23 @@ def add(
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Add a batch of data into ReplayBufferManager.

Each of the data's length (first dimension) must equal to the length of
buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].
Each of the data's lengths (first dimension) must be equal to the length of
buffer_ids. By default, buffer_ids is [0, 1, ..., buffer_num - 1].

Return (current_index, episode_reward, episode_length, episode_start_index). If
the episode is not finished, the return value of episode_length and
episode_reward is 0.
"""
# preprocess batch
new_batch = Batch()
for key in set(self._reserved_keys).intersection(batch.keys()):
for key in set(self._RESERVED_KEYS).intersection(batch.keys()):
new_batch.__dict__[key] = batch[key]
batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys())
if missing_keys := self._REQUIRED_KEYS.difference(batch.keys()):
raise RuntimeError(
f"The input batch you try to add is missing the keys {missing_keys}.",
)
if self._save_only_last_obs:
batch.obs = batch.obs[:, -1]
if not self._save_obs_next:
Expand All @@ -141,21 +144,23 @@ def add(
# get index
if buffer_ids is None:
buffer_ids = np.arange(self.buffer_num)
ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], []
ep_add_at_idxs, ep_lens, ep_rews, ep_start_idxs = [], [], [], []
for batch_idx, buffer_id in enumerate(buffer_ids):
ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index(
ep_add_at_idx, ep_rew, ep_len, ep_start_idx = self.buffers[
buffer_id
]._update_buffer_state_after_adding_batch(
batch.rew[batch_idx],
batch.done[batch_idx],
)
ptrs.append(ptr + self._offset[buffer_id])
ep_add_at_idxs.append(ep_add_at_idx + self._offset[buffer_id])
ep_lens.append(ep_len)
ep_rews.append(ep_rew)
ep_idxs.append(ep_idx + self._offset[buffer_id])
self.last_index[buffer_id] = ptr + self._offset[buffer_id]
ep_start_idxs.append(ep_start_idx + self._offset[buffer_id])
self.last_index[buffer_id] = ep_add_at_idx + self._offset[buffer_id]
self._lengths[buffer_id] = len(self.buffers[buffer_id])
ptrs = np.array(ptrs)
ep_add_at_idxs = np.array(ep_add_at_idxs)
try:
self._meta[ptrs] = batch
self._meta[ep_add_at_idxs] = batch
except ValueError:
batch.rew = batch.rew.astype(float)
batch.done = batch.done.astype(bool)
Expand All @@ -166,8 +171,8 @@ def add(
else: # dynamic key pops up in batch
alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
self._set_batch_for_children()
self._meta[ptrs] = batch
return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
self._meta[ep_add_at_idxs] = batch
return ep_add_at_idxs, np.array(ep_rews), np.array(ep_lens), np.array(ep_start_idxs)

def sample_indices(self, batch_size: int | None) -> np.ndarray:
# TODO: simplify this code
Expand Down Expand Up @@ -214,7 +219,7 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage

:param buffer_list: a list of PrioritizedReplayBuffer needed to be handled.

.. seealso::
.. seealso::

Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
Expand All @@ -235,7 +240,7 @@ class HERReplayBufferManager(ReplayBufferManager):

:param buffer_list: a list of HERReplayBuffer needed to be handled.

.. seealso::
.. seealso::

Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
Expand Down
Loading