diff --git a/rl_toolkit/agents/dueling_dqn/agent.py b/rl_toolkit/agents/dueling_dqn/agent.py index e04d456..02c3b42 100644 --- a/rl_toolkit/agents/dueling_dqn/agent.py +++ b/rl_toolkit/agents/dueling_dqn/agent.py @@ -134,10 +134,14 @@ def collect(self, writer, policy): action = np.array(action, copy=False, dtype=self._env.action_space.dtype) # Perform action - new_obs, ext_reward, terminated, truncated, _ = self._env.step(action) + new_obs, ext_reward, terminated, truncated, info = self._env.step(action) # Update variables - self._episode_reward += ext_reward + try: + self._episode_reward = info["score"] + except KeyError: + self._episode_reward += ext_reward + self._episode_steps += 1 self._total_steps += 1