Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
markub3327 committed Nov 17, 2023
1 parent 0f4f7ca commit 1992a97
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rl_toolkit/networks/models/dueling.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def _compute_n_step_rewards(
self, rewards, discount_factor, next_state_value, is_terminal
):
n = tf.shape(rewards)[1]
rewards = tf.squeeze(rewards, axis=-1)

# Create a discount factor tensor
discounts = discount_factor ** tf.range(n + 1, dtype=rewards.dtype)
Expand Down Expand Up @@ -234,9 +235,8 @@ def _compute_n_step_rewards(
# Add the next state value with discount
n_step_rewards += (
(1.0 - is_terminal[:, tf.newaxis])
* tf.reverse(discounts[1:], axis=[0])[
tf.newaxis, :
] * next_state_value[:, tf.newaxis]
* tf.reverse(discounts[1:], axis=[0])[tf.newaxis, :]
* next_state_value[:, tf.newaxis]
)

return n_step_rewards
Expand Down

0 comments on commit 1992a97

Please sign in to comment.