diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6a31f4f19..9e0c7ec8d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: stablebaselines/stable-baselines3-cpu:0.11.0a4 +image: stablebaselines/stable-baselines3-cpu:0.11.1 type-check: script: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5d12f4636..64958df83 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,15 @@ Changelog ========== +Pre-Release 0.11.1 (2021-02-27) +------------------------------- + +Bug Fixes: +^^^^^^^^^^ +- Fixed a bug where ``train_freq`` was not properly converted when loading a saved model + + + Pre-Release 0.11.0 (2021-02-27) ------------------------------- diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 09cc2aae7..164174b19 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -131,15 +131,8 @@ def __init__( # see https://github.com/hill-a/stable-baselines/issues/863 self.remove_time_limit_termination = remove_time_limit_termination - if isinstance(train_freq, int): - train_freq = (train_freq, "step") - - try: - train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) - except ValueError: - raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!") - - self.train_freq = TrainFreq(*train_freq) + # Save train freq parameter, will be converted later to TrainFreq object + self.train_freq = train_freq self.actor = None # type: Optional[th.nn.Module] self.replay_buffer = None # type: Optional[ReplayBuffer] @@ -149,6 +142,28 @@ def __init__( # For gSDE only self.use_sde_at_warmup = use_sde_at_warmup + def _convert_train_freq(self) -> None: + """ + Convert `train_freq` parameter (int or tuple) + to a TrainFreq object. + """ + if not isinstance(self.train_freq, TrainFreq): + train_freq = self.train_freq + + # The value of the train frequency will be checked later + if not isinstance(train_freq, tuple): + train_freq = (train_freq, "step") + + try: + train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) + except ValueError: + raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!") + + if not isinstance(train_freq[0], int): + raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}") + + self.train_freq = TrainFreq(*train_freq) + def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) @@ -167,6 +182,9 @@ def _setup_model(self) -> None: ) self.policy = self.policy.to(self.device) + # Convert train freq parameter to TrainFreq object + self._convert_train_freq() + def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: """ Save the replay buffer as a pickle file. diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index d9df1bbc0..af88ba824 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.11.0 +0.11.1 diff --git a/tests/test_run.py b/tests/test_run.py index fae6782d6..c588a0257 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -103,3 +103,39 @@ def test_dqn(): create_eval_env=True, ) model.learn(total_timesteps=500, eval_freq=250) + + +@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")]) +def test_train_freq(tmp_path, train_freq): + + model = SAC( + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[64, 64], n_critics=1), + learning_starts=100, + buffer_size=10000, + verbose=1, + train_freq=train_freq, + ) + model.learn(total_timesteps=150) + model.save(tmp_path / "test_save.zip") + env = model.get_env() + model = SAC.load(tmp_path / "test_save.zip", env=env) + model.learn(total_timesteps=150) + model = SAC.load(tmp_path / "test_save.zip", train_freq=train_freq, env=env) + model.learn(total_timesteps=150) + + +@pytest.mark.parametrize("train_freq", ["4", ("1", "episode"), "non_sense", (1, "close")]) +def test_train_freq_fail(train_freq): + with pytest.raises(ValueError): + model = SAC( + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[64, 64], n_critics=1), + learning_starts=100, + buffer_size=10000, + verbose=1, + train_freq=train_freq, + ) + model.learn(total_timesteps=250) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index bcc85b657..369b82a49 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -176,7 +176,7 @@ def test_set_env(model_class): kwargs = {} if model_class in {DQN, DDPG, SAC, TD3}: - kwargs = dict(learning_starts=100) + kwargs = dict(learning_starts=100, train_freq=4) elif model_class in {A2C, PPO}: kwargs = dict(n_steps=64) @@ -238,12 +238,12 @@ def test_save_load_env_cnn(tmp_path, model_class): env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False) kwargs = dict(policy_kwargs=dict(net_arch=[32])) if model_class == TD3: - kwargs.update(dict(buffer_size=100, learning_starts=50)) + kwargs.update(dict(buffer_size=100, learning_starts=50, train_freq=4)) model = model_class("CnnPolicy", env, **kwargs).learn(100) model.save(tmp_path / "test_save") # Test loading with env and continuing training - model = model_class.load(str(tmp_path / "test_save.zip"), env=env).learn(100) + model = model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100) # clear file from os os.remove(tmp_path / "test_save.zip")