diff --git a/rl_toolkit/networks/models/dueling.py b/rl_toolkit/networks/models/dueling.py index f526e4e..81bee4b 100644 --- a/rl_toolkit/networks/models/dueling.py +++ b/rl_toolkit/networks/models/dueling.py @@ -198,6 +198,7 @@ def _compute_n_step_rewards( self, rewards, discount_factor, next_state_value, is_terminal ): n = tf.shape(rewards)[1] + rewards = tf.squeeze(rewards, axis=-1) # Create a discount factor tensor discounts = discount_factor ** tf.range(n + 1, dtype=rewards.dtype) @@ -234,9 +235,8 @@ def _compute_n_step_rewards( # Add the next state value with discount n_step_rewards += ( (1.0 - is_terminal[:, tf.newaxis]) - * tf.reverse(discounts[1:], axis=[0])[ - tf.newaxis, : - ] * next_state_value[:, tf.newaxis] + * tf.reverse(discounts[1:], axis=[0])[tf.newaxis, :] + * next_state_value[:, tf.newaxis] ) return n_step_rewards