Skip to content

Commit

Permalink
Fix train_freq at load time (#332)
Browse files Browse the repository at this point in the history
* Fix train_freq loading

* Update docker

* Add sanity checks + tests for train freq
  • Loading branch information
araffin authored Feb 27, 2021
1 parent 0c50d75 commit b2c94a6
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .gitlab-ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
image: stablebaselines/stable-baselines3-cpu:0.11.0a4
image: stablebaselines/stable-baselines3-cpu:0.11.1

type-check:
script:
Expand Down
9 changes: 9 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
Changelog
==========

Pre-Release 0.11.1 (2021-02-27)
-------------------------------

Bug Fixes:
^^^^^^^^^^
- Fixed a bug where ``train_freq`` was not properly converted when loading a saved model



Pre-Release 0.11.0 (2021-02-27)
-------------------------------

Expand Down
36 changes: 27 additions & 9 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,8 @@ def __init__(
# see https://github.com/hill-a/stable-baselines/issues/863
self.remove_time_limit_termination = remove_time_limit_termination

if isinstance(train_freq, int):
train_freq = (train_freq, "step")

try:
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
except ValueError:
raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")

self.train_freq = TrainFreq(*train_freq)
# Save train freq parameter, will be converted later to TrainFreq object
self.train_freq = train_freq

self.actor = None # type: Optional[th.nn.Module]
self.replay_buffer = None # type: Optional[ReplayBuffer]
Expand All @@ -149,6 +142,28 @@ def __init__(
# For gSDE only
self.use_sde_at_warmup = use_sde_at_warmup

def _convert_train_freq(self) -> None:
    """
    Convert `train_freq` parameter (int or tuple)
    to a TrainFreq object.

    Does nothing when ``self.train_freq`` is already a ``TrainFreq``,
    so calling it several times is harmless.

    :raises ValueError: if the unit is not a valid ``TrainFrequencyUnit``
        value, or if the frequency is not an integer.
    """
    if isinstance(self.train_freq, TrainFreq):
        return

    train_freq = self.train_freq

    # Anything that is not a tuple is treated as a frequency expressed in steps;
    # the frequency value itself is validated further down
    if not isinstance(train_freq, tuple):
        train_freq = (train_freq, "step")

    try:
        train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
    except ValueError as err:
        # Chain explicitly so the traceback shows the enum lookup as the cause
        # rather than as an unrelated error raised while handling another one
        raise ValueError(
            f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
        ) from err

    if not isinstance(train_freq[0], int):
        raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

    self.train_freq = TrainFreq(*train_freq)

def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
Expand All @@ -167,6 +182,9 @@ def _setup_model(self) -> None:
)
self.policy = self.policy.to(self.device)

# Convert train freq parameter to TrainFreq object
self._convert_train_freq()

def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
"""
Save the replay buffer as a pickle file.
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.11.0
0.11.1
36 changes: 36 additions & 0 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,39 @@ def test_dqn():
create_eval_env=True,
)
model.learn(total_timesteps=500, eval_freq=250)


@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
def test_train_freq(tmp_path, train_freq):
    """Train, save and reload SAC for each supported ``train_freq`` form."""
    save_path = tmp_path / "test_save.zip"
    sac_kwargs = dict(
        policy_kwargs=dict(net_arch=[64, 64], n_critics=1),
        learning_starts=100,
        buffer_size=10000,
        verbose=1,
        train_freq=train_freq,
    )

    model = SAC("MlpPolicy", "Pendulum-v0", **sac_kwargs)
    model.learn(total_timesteps=150)
    model.save(save_path)
    env = model.get_env()

    # Reload first with the stored value, then with an explicit override,
    # and continue training each time
    for load_kwargs in ({}, {"train_freq": train_freq}):
        model = SAC.load(save_path, env=env, **load_kwargs)
        model.learn(total_timesteps=150)


@pytest.mark.parametrize("train_freq", ["4", ("1", "episode"), "non_sense", (1, "close")])
def test_train_freq_fail(train_freq):
    """Check that an invalid ``train_freq`` raises ``ValueError`` when building or training SAC."""
    sac_kwargs = dict(
        policy_kwargs=dict(net_arch=[64, 64], n_critics=1),
        learning_starts=100,
        buffer_size=10000,
        verbose=1,
        train_freq=train_freq,
    )

    # Construction and training both sit inside the context, so the error may come from either
    with pytest.raises(ValueError):
        SAC("MlpPolicy", "Pendulum-v0", **sac_kwargs).learn(total_timesteps=250)
6 changes: 3 additions & 3 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_set_env(model_class):

kwargs = {}
if model_class in {DQN, DDPG, SAC, TD3}:
kwargs = dict(learning_starts=100)
kwargs = dict(learning_starts=100, train_freq=4)
elif model_class in {A2C, PPO}:
kwargs = dict(n_steps=64)

Expand Down Expand Up @@ -238,12 +238,12 @@ def test_save_load_env_cnn(tmp_path, model_class):
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False)
kwargs = dict(policy_kwargs=dict(net_arch=[32]))
if model_class == TD3:
kwargs.update(dict(buffer_size=100, learning_starts=50))
kwargs.update(dict(buffer_size=100, learning_starts=50, train_freq=4))

model = model_class("CnnPolicy", env, **kwargs).learn(100)
model.save(tmp_path / "test_save")
# Test loading with env and continuing training
model = model_class.load(str(tmp_path / "test_save.zip"), env=env).learn(100)
model = model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
# clear file from os
os.remove(tmp_path / "test_save.zip")

Expand Down

0 comments on commit b2c94a6

Please sign in to comment.