Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support asynchronous human preference gathering in RLHP implementation #716

Open
wants to merge 165 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
165 commits
Select commit Hold shift + click to select a range
63e9112
Adds RenderImageInfoWrapper
rk1a Dec 15, 2022
ed5362c
Add PrefCollectGatherer
rk1a Dec 15, 2022
9359ad4
Adds configuration for post wrappers
rk1a Dec 15, 2022
eec88e0
Adds named_config for human preferences
rk1a Dec 15, 2022
346efde
Add experiment script for human preferences
rk1a Dec 15, 2022
e6a4400
Fix post wrappers config
rk1a Dec 16, 2022
4fa0726
Fix post wrappers config
rk1a Dec 30, 2022
ae0087f
Fix PrefCollect address
rk1a Dec 30, 2022
5257b11
Merge branch 'human-pref-gatherer' of https://github.com/rk1a/imitati…
rk1a Dec 30, 2022
4b3968b
Extract preference querying into querent class
rk1a Feb 15, 2023
07c3866
Merge branch 'human-pref-gatherer' into human-prefs
rk1a Feb 15, 2023
8ccaa71
Add querent and gatherer for human preferences
rk1a Mar 1, 2023
3883a9f
Add PreferenceQuerent tests, one for PrefCollectQuerent
Mar 9, 2023
05f1da7
Correct method signature
Mar 16, 2023
7ba60c3
Test PreferenceGatherer and partially test SyntheticGatherer
Mar 16, 2023
65acb03
Fix bug
Apr 17, 2023
2185a1e
Add gather preferences tests
Apr 17, 2023
f7b22f2
Add pref collect gatherer tests
Apr 28, 2023
c0e884d
Add PrefCollectGatherer tests
rk1a Apr 28, 2023
fa587e5
Add todos
May 5, 2023
b27b8eb
Merge remote-tracking branch 'origin/human-prefs' into human-prefs
May 5, 2023
9208b89
Add support for a simple preference UI
timbauman May 11, 2023
97c2171
style
timbauman May 11, 2023
5182acc
style
timbauman May 11, 2023
f176bdc
comments
timbauman May 11, 2023
2540ba4
add notebook
timbauman May 11, 2023
3ff9579
tutorial
timbauman May 11, 2023
ce5efa7
lint
timbauman May 11, 2023
2ba2994
lint
timbauman May 11, 2023
b3fd565
more lint
timbauman May 12, 2023
002146f
oops
timbauman May 12, 2023
9bb1f00
add test
timbauman May 13, 2023
f7b7a0d
lint
timbauman May 13, 2023
4e50ca1
Merge branch 'master' into human-prefs
May 15, 2023
4b74f50
Merge remote-tracking branch 'imitation/711-support-simple-synchronou…
May 15, 2023
d82542c
Integrate SynchronousHumanPreferenceGatherer
May 15, 2023
72ad6aa
Fix remaining tests
May 19, 2023
5be9087
Fix flake8
May 25, 2023
8e3c7b7
Fix flake8 and codespell
rk1a May 25, 2023
b23abcd
Fix some mypy errors
rk1a Jun 8, 2023
b6fb914
Fix mypy
Jun 8, 2023
d230f44
Fix mypy error
rk1a Jun 8, 2023
c7d9010
Fix mypy
Jun 8, 2023
2de7491
Fix mypy, flake8, codespell
rk1a Jun 25, 2023
c7f6bac
Adds preference_querent to args-is-none-check
Jun 26, 2023
06a8e1e
Merge remote-tracking branch 'origin/human-prefs' into human-prefs
Jun 26, 2023
9f1de17
Adds querent to pref comparisons in tutorial 5
Jul 2, 2023
27881a7
Adds querent to pref comparisons in tutorial 5a
Jul 2, 2023
69983ef
Adds querent to pref comparisons in tutorial 5b and docs
Jul 2, 2023
54be4be
Fix notebooks and docs
rk1a Jul 2, 2023
657a8b2
Fix bug
rk1a Jul 2, 2023
a056476
Removes active check
Jun 27, 2023
a2ec16d
Adds missing hyphen
Jun 27, 2023
2bf5c94
Merge remote-tracking branch 'origin/human-prefs' into human-prefs
Aug 4, 2023
c2e7686
Merge branch 'master' into human-prefs
Aug 4, 2023
97d4565
Adds rng argument to querent
rk1a Aug 16, 2023
a861ce8
Add querent default config
rk1a Aug 16, 2023
fa70a3b
Fix tests
rk1a Aug 16, 2023
e0b44e5
Fix test
Aug 25, 2023
30578ca
Fix errors related to post_wrappers
rk1a Sep 5, 2023
d15cab1
Fix tests
rk1a Sep 14, 2023
f0728cd
Fix test
Sep 19, 2023
deaca12
Merge remote-tracking branch 'origin/human-prefs' into human-prefs
Sep 19, 2023
2a15d82
Merge master
Sep 25, 2023
4dbf20c
Fix test
rk1a Oct 6, 2023
ef1134a
Refactors gatherer and adapts some tests
rk1a Oct 6, 2023
046183f
Fixes gym import bug
Nov 2, 2023
d8c5d98
Fixes bug
Nov 2, 2023
d341090
Fix tests
Nov 2, 2023
f8c42bc
Fix tests
rk1a Nov 2, 2023
bb04449
Merge remote-tracking branch 'origin/refactor-gatherer' into refactor…
Nov 2, 2023
ef11c41
Merge branch 'refactor-gatherer' into human-prefs
Nov 2, 2023
4c62f6c
Merge branch 'master' into human-prefs
Nov 2, 2023
b74bb5a
Fix tests
rk1a Nov 6, 2023
ba20500
Adapt render image infor wrapper to gymnasium
Nov 10, 2023
92d3395
Merge remote-tracking branch 'origin/human-prefs' into human-prefs
Nov 10, 2023
92bbf95
Add querent kwargs and change config accordingly
Nov 14, 2023
6b0e92e
Fix default video dir
rk1a Nov 21, 2023
0c94fec
Merge branch 'HumanCompatibleAI:master' into human-prefs
rk1a Dec 10, 2023
85f01d4
Add test for video writing, fix some precommit errors
rk1a Dec 10, 2023
be8e0a8
Add and fix more tests
rk1a Dec 10, 2023
7f6be46
add test for remove_rendered_images
rk1a Jan 3, 2024
ba25756
Add test for preference comparisons with collected preferences
Jan 11, 2024
5e780b1
Merge branch 'human-prefs' of https://github.com/rk1a/imitation into …
rk1a Jan 11, 2024
cc03039
Add tests and fix bug for RenderImageWrapper
rk1a Jan 22, 2024
2632a88
Add missing docstring
rk1a Jan 22, 2024
6cf798b
Merge branch 'master-upstream' into human-prefs
Jan 24, 2024
4876d60
Cleanup changes made to docs
Jan 24, 2024
356defa
Remove benchmark_summary.md
Feb 15, 2024
4c7031e
Initial commit of Zooniverse preference comparisons
Feb 29, 2024
821d1a3
Add querent_kwargs to ZooniverseGatherer
Feb 29, 2024
b68d61d
Refactor Zooniverse elements handling
Feb 29, 2024
1cf4ca2
Override PrefCollectGatherer self.querent
Feb 29, 2024
a285639
Remove querent_kwargs from super().__init__
Feb 29, 2024
ec126d8
Handle querent_kwargs requirements
Feb 29, 2024
f0946f0
Handle querent_kwargs requirements
Mar 1, 2024
16416ec
Pass pref_collect_address to ZooniverseQuerent
Mar 1, 2024
75db939
Handle querent_kwargs requirements
Mar 1, 2024
ba48d73
ZooniverseGatherer fix super().__init__ call.
Mar 1, 2024
3634e60
ZooniverseGatherer fix super().__init__ call.
Mar 1, 2024
d04b505
ZooniverseGatherer fix super().__init__ call.
Mar 1, 2024
5bd2cbd
ZooniverseGatherer fix super().__init__ call.
Mar 1, 2024
9423e9b
Remove experiment_id attr from ZooniverseGatherer.
Mar 1, 2024
90760fd
Remove self from method call.
Mar 1, 2024
f00ad8f
Fix incorrect var names
Mar 1, 2024
684550d
Remove repeated makedirs call.
Mar 1, 2024
4abc06d
Fix output_file_name
dr-darryl-wright Mar 4, 2024
4430c32
Make video_fps an attribute
dr-darryl-wright Mar 4, 2024
ab23fda
Override Querent __call__ to write .mp4
dr-darryl-wright Mar 4, 2024
5db471f
Call PreferenceQuerent not PrefCollectQuerent
dr-darryl-wright Mar 4, 2024
26da521
Write .webm
dr-darryl-wright Mar 4, 2024
3b51a52
Write .mp4
dr-darryl-wright Mar 4, 2024
316f4b3
Minor refactor
dr-darryl-wright Mar 4, 2024
6392ea4
Fix annotation_to_label mapping
dr-darryl-wright Mar 4, 2024
92f3341
Fix annotation_to_label mapping
dr-darryl-wright Mar 4, 2024
9044f5e
Invert subjet_to_query map
dr-darryl-wright Mar 4, 2024
359a270
Invert subjet_to_query map
dr-darryl-wright Mar 4, 2024
04d9e23
Process classifications with each gather call
dr-darryl-wright Mar 4, 2024
93e86bd
Fix Counter import
dr-darryl-wright Mar 4, 2024
6a3c3cd
Allow writing .gif
dr-darryl-wright Mar 5, 2024
a28fb20
Allow writing .gif
dr-darryl-wright Mar 5, 2024
fa69e57
Remove fps from write_gif call
dr-darryl-wright Mar 5, 2024
b7632f3
Add ffmpeg and logger to write_gif call
dr-darryl-wright Mar 5, 2024
8e51427
Add ffmpeg and logger to write_gif call
dr-darryl-wright Mar 5, 2024
354ea10
Zoo authenticate for each query and gather call to avoid timeout
Mar 8, 2024
1bca7bf
Do not move retired subjects to a separate subject set. This avoids a…
Mar 8, 2024
04a0775
Add None annotation to label map.
Mar 8, 2024
28d5f09
Make last_id trackable.
Mar 10, 2024
50af8bc
Make last_id trackable.
Mar 10, 2024
4640623
Make last_id trackable.
Mar 10, 2024
0261e99
Fix last_id NameError
Mar 10, 2024
f39849f
Fix UnboundLocalError
Mar 10, 2024
67b1dd8
Handle deletion of subjects on panoptes FE
Mar 21, 2024
55c0e52
Handle deletion of subjects on panoptes FE
Mar 21, 2024
cd45613
Extract video writing from PrefCollectGatherer to new parent class
Mar 27, 2024
389b719
Extract handling of asynchronous preference collection to new parent …
Mar 27, 2024
9e57c81
Adapt SynchronousHumanGatherer to use new VideoBasedQuerent
Mar 27, 2024
b7e5382
Rename pref collect gatherer/querent to REST gatherer/querent
Mar 27, 2024
58a5e34
Merge branch 'human-prefs-base' into zoo-prefs-base
Mar 27, 2024
a951dc5
Adjust ZooniverseGatherer/Querent to new base classes
Mar 27, 2024
fb94706
Remove Zooniverse classes
Mar 27, 2024
ee40083
Add empty _query method to VideoBasedQuerent
Mar 27, 2024
cb1be27
Split video writing into smaller methods
Mar 27, 2024
0646dc7
Remove unused imports
Mar 27, 2024
09ddc0e
Remove support for observations in video writing
rk1a Apr 8, 2024
e64dc00
Merge pull request #3 from rk1a/refactor-human-prefs
rk1a Apr 8, 2024
5d9a160
Transfer video data via REST request and simplify VideoBasedQuerent
rk1a Apr 8, 2024
a3175ac
Fix tests as far as possible
rk1a Apr 28, 2024
7851d0f
Refactor Gatherer classes
May 7, 2024
df49129
Add test for video loading method
rk1a May 16, 2024
8a6e210
Fix bug in SyntheticGatherer
May 17, 2024
f934886
Adapt ConcretePreferenceGatherer
May 17, 2024
040af8b
Improve variable naming
May 17, 2024
891e7f7
Rename test for CommandLineGatherer
May 17, 2024
4b9d1f9
Fix test
May 17, 2024
2c632d8
Add documentation
May 17, 2024
eebb31d
Add documentation
May 17, 2024
f2f722a
Make method static
May 17, 2024
4c16550
Remove whitespace
May 17, 2024
7ba6e5f
Merge branch 'human-prefs-base' of https://github.com/rk1a/imitation …
rk1a May 17, 2024
f3fde8a
Merge remote-tracking branch 'origin/human-prefs-base' into human-pre…
May 17, 2024
7b2a260
Refine rest interface and add documentation
May 17, 2024
38c7df9
Fix bug and integration test
May 28, 2024
fcfa92a
Delete videos after test
May 28, 2024
d263216
Merge pull request #4 from rk1a/human-prefs-base
rk1a May 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions ci/clean_notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class UncleanNotebookError(Exception):
"metadata": {"do": "constant", "value": dict()},
"source": {"do": "keep"},
"id": {"do": "keep"},
"attachments": {"do": "constant", "value": {}},
}

code_structure: Dict[str, Dict[str, Any]] = {
Expand Down Expand Up @@ -76,7 +77,8 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None:
if key not in structure[cell["cell_type"]]:
if check_only:
raise UncleanNotebookError(
f"Notebook {file} has unknown cell key {key}",
f"Notebook {file} has unknown cell key {key} for cell type "
+ f"{cell['cell_type']}",
)
del cell[key]
was_dirty = True
Expand Down Expand Up @@ -108,7 +110,12 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None:


def parse_args():
"""Parse command-line arguments."""
"""Parse command-line arguments.

Returns:
parser: The parser object.
args: The parsed arguments.
"""
# if the argument --check has been passed, check if the notebooks are clean
# otherwise, clean them in-place
parser = argparse.ArgumentParser()
Expand All @@ -125,7 +132,14 @@ def parse_args():


def get_files(input_paths: List):
"""Build list of files to scan from list of paths and files."""
"""Build list of files to scan from list of paths and files.

Args:
input_paths: List of paths and files to scan.

Returns:
files: List of files to scan.
"""
files = []
for file in input_paths:
if file.is_dir():
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ If you use ``imitation`` in your research project, please cite our paper to help
tutorials/4_train_airl
tutorials/5_train_preference_comparisons
tutorials/5a_train_preference_comparisons_with_cnn
tutorials/5b_train_preference_comparisons_with_synchronous_human_feedback
tutorials/6_train_mce
tutorials/7_train_density
tutorials/8_train_sqil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/5_train_preference_comparisons.ipynb)\n",
"# Learning a Reward Function using Preference Comparisons with Synchronous Human Feedback\n",
"\n",
"You can request human feedback via synchronous CLI or Notebook interactions as well. The setup is only slightly different than it would be with a synthetic preference gatherer."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's the starting setup. The major differences from the synthetic setup are indicated with comments"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pathlib\n",
"import random\n",
"import tempfile\n",
"from imitation.algorithms import preference_comparisons\n",
"from imitation.rewards.reward_nets import BasicRewardNet\n",
"from imitation.util import video_wrapper\n",
"from imitation.util.networks import RunningNorm\n",
"from imitation.util.util import make_vec_env\n",
"from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor\n",
"import gym\n",
"from stable_baselines3 import PPO\n",
"import numpy as np\n",
"\n",
"# Add a temporary directory for video recordings of trajectories. Unfortunately Jupyter\n",
"# won't play videos outside the current directory, so we have to put them here. We'll\n",
"# delete them at the end of the script.\n",
"video_dir = tempfile.mkdtemp(dir=\".\", prefix=\"videos_\")\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Add a video wrapper to the environment. This will record videos of the agent's\n",
"# trajectories so we can review them later.\n",
"venv = make_vec_env(\n",
" \"Pendulum-v1\",\n",
" rng=rng,\n",
" post_wrappers={\n",
" \"VideoWrapper\": video_wrapper.video_wrapper_factory(\n",
" pathlib.Path(video_dir), single_video=False\n",
" )\n",
" },\n",
")\n",
"\n",
"reward_net = BasicRewardNet(\n",
" venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm\n",
")\n",
"\n",
"fragmenter = preference_comparisons.RandomFragmenter(\n",
" warning_threshold=0,\n",
" rng=rng,\n",
")\n",
"\n",
"querent = preference_comparisons.PreferenceQuerent()\n",
"\n",
"# This gatherer will show the user (you!) pairs of trajectories and ask it to choose\n",
"# which one is better. It will then use the user's feedback to train the reward network.\n",
"gatherer = preference_comparisons.SynchronousHumanGatherer(video_dir=video_dir)\n",
"\n",
"preference_model = preference_comparisons.PreferenceModel(reward_net)\n",
"reward_trainer = preference_comparisons.BasicRewardTrainer(\n",
" preference_model=preference_model,\n",
" loss=preference_comparisons.CrossEntropyRewardLoss(),\n",
" epochs=3,\n",
" rng=rng,\n",
")\n",
"\n",
"agent = PPO(\n",
" policy=FeedForward32Policy,\n",
" policy_kwargs=dict(\n",
" features_extractor_class=NormalizeFeaturesExtractor,\n",
" features_extractor_kwargs=dict(normalize_class=RunningNorm),\n",
" ),\n",
" env=venv,\n",
" seed=0,\n",
" n_steps=2048 // venv.num_envs,\n",
" batch_size=64,\n",
" ent_coef=0.0,\n",
" learning_rate=0.0003,\n",
" n_epochs=10,\n",
")\n",
"\n",
"trajectory_generator = preference_comparisons.AgentTrainer(\n",
" algorithm=agent,\n",
" reward_fn=reward_net,\n",
" venv=venv,\n",
" exploration_frac=0.0,\n",
" rng=rng,\n",
")\n",
"\n",
"pref_comparisons = preference_comparisons.PreferenceComparisons(\n",
" trajectory_generator,\n",
" reward_net,\n",
" num_iterations=5,\n",
" fragmenter=fragmenter,\n",
" preference_querent=querent,\n",
" preference_gatherer=gatherer,\n",
" reward_trainer=reward_trainer,\n",
" fragment_length=100,\n",
" transition_oversampling=1,\n",
" initial_comparison_frac=0.1,\n",
" allow_variable_horizon=False,\n",
" initial_epoch_multiplier=1,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We're going to train with only 20 comparisons to make it faster for you to evaluate. The videos will appear in-line in this notebook for you to watch, and a text input will appear for you to choose one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pref_comparisons.train(\n",
" total_timesteps=5_000, # For good performance this should be 1_000_000\n",
" total_comparisons=20, # For good performance this should be 5_000\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"From this point onward, this notebook is the same as [the synthetic gatherer notebook](5_train_preference_comparisons.ipynb).\n",
"\n",
"After we trained the reward network using the preference comparisons algorithm, we can wrap our environment with that learned reward."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from imitation.rewards.reward_wrapper import RewardVecEnvWrapper\n",
"\n",
"\n",
"learned_reward_venv = RewardVecEnvWrapper(venv, reward_net.predict)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can train an agent, that only sees those learned reward."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3 import PPO\n",
"from stable_baselines3.ppo import MlpPolicy\n",
"\n",
"learner = PPO(\n",
" policy=MlpPolicy,\n",
" env=learned_reward_venv,\n",
" seed=0,\n",
" batch_size=64,\n",
" ent_coef=0.0,\n",
" learning_rate=0.0003,\n",
" n_epochs=10,\n",
" n_steps=64,\n",
")\n",
"learner.learn(1000) # Note: set to 100000 to train a proficient expert"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can evaluate it using the original reward."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3.common.evaluation import evaluate_policy\n",
"\n",
"reward, _ = evaluate_policy(learner.policy, venv, 10)\n",
"print(reward)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# clean up the videos we made\n",
"import shutil\n",
"\n",
"shutil.rmtree(video_dir)"
]
}
],
"metadata": {
"interpreter": {
"hash": "439158cd89905785fcc749928062ade7bfccc3f087fab145e5671f895c635937"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
"wandb==0.12.21",
"setuptools_scm~=7.0.5",
"pre-commit>=2.20.0",
"types-requests~=2.31.0.1",
"requests-mock~=1.11.0",
]
+ PARALLEL_REQUIRE
+ ATARI_REQUIRE
Expand Down Expand Up @@ -209,6 +211,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"huggingface_sb3~=3.0",
"optuna>=3.0.1",
"datasets>=2.8.0",
"opencv-python", # TODO: specify version
],
tests_require=TESTS_REQUIRE,
extras_require={
Expand Down
Loading