-
Notifications
You must be signed in to change notification settings - Fork 290
Hydra exploration #703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Hydra exploration #703
Changes from all commits
993e059
1328c2a
a72a25d
89214d9
edd12c4
c014dfb
6c2b3a1
fe0bc1d
2906d55
f1366e6
cb92c33
b8221ab
0912138
a0c0d98
1ff9aa1
e37fd4c
59c2b08
244eed0
ff0028f
ca632a8
1738822
5638222
efe3b90
5915a80
6760104
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Hydra configurations and scripts that form a CLI for imitation.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| """Config and run configuration for AIRL.""" | ||
| import dataclasses | ||
| import logging | ||
| import pathlib | ||
| from typing import Any, Dict, Sequence, cast | ||
|
|
||
| import hydra | ||
| import torch as th | ||
| from hydra.core.config_store import ConfigStore | ||
| from hydra.utils import instantiate | ||
| from omegaconf import MISSING | ||
|
|
||
| from imitation.policies import serialize | ||
| from imitation_cli.algorithm_configurations import airl as airl_cfg | ||
| from imitation_cli.utils import environment as environment_cfg | ||
| from imitation_cli.utils import ( | ||
| policy_evaluation, | ||
| randomness, | ||
| reward_network, | ||
| rl_algorithm, | ||
| trajectories, | ||
| ) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class RunConfig: | ||
| """Config for running AIRL.""" | ||
|
|
||
| rng: randomness.Config = randomness.Config(seed=0) | ||
| total_timesteps: int = int(1e6) | ||
| checkpoint_interval: int = 0 | ||
|
|
||
| environment: environment_cfg.Config = MISSING | ||
| airl: airl_cfg.Config = MISSING | ||
| evaluation: policy_evaluation.Config = MISSING | ||
| # This ensures that the working directory is changed | ||
| # to the hydra output dir | ||
| hydra: Any = dataclasses.field(default_factory=lambda: dict(job=dict(chdir=True))) | ||
|
|
||
|
|
||
| cs = ConfigStore.instance() | ||
| environment_cfg.register_configs("environment", "${rng}") | ||
| trajectories.register_configs("airl/demonstrations", "${environment}", "${rng}") | ||
| rl_algorithm.register_configs("airl/gen_algo", "${environment}", "${rng.seed}") | ||
| reward_network.register_configs("airl/reward_net", "${environment}") | ||
| policy_evaluation.register_configs("evaluation", "${environment}", "${rng}") | ||
|
|
||
| cs.store( | ||
| name="airl_run_base", | ||
| node=RunConfig( | ||
| airl=airl_cfg.Config( | ||
| venv="${environment}", # type: ignore[arg-type] | ||
| ), | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| @hydra.main( | ||
| version_base=None, | ||
| config_path="config", | ||
| config_name="airl_run", | ||
| ) | ||
| def run_airl(cfg: RunConfig) -> Dict[str, Any]: | ||
| from imitation.algorithms.adversarial import airl | ||
| from imitation.data import rollout | ||
| from imitation.data.types import TrajectoryWithRew | ||
|
|
||
| trainer: airl.AIRL = instantiate(cfg.airl) | ||
|
|
||
| checkpoints_path = pathlib.Path("checkpoints") | ||
|
|
||
| def save(path: str): | ||
| """Save discriminator and generator.""" | ||
| # We implement this here and not in Trainer since we do not want to actually | ||
| # serialize the whole Trainer (including e.g. expert demonstrations). | ||
| save_path = checkpoints_path / path | ||
| save_path.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| th.save(trainer.reward_train, save_path / "reward_train.pt") | ||
| th.save(trainer.reward_test, save_path / "reward_test.pt") | ||
| serialize.save_stable_model(save_path / "gen_policy", trainer.gen_algo) | ||
|
|
||
| def callback(round_num: int, /) -> None: | ||
| if cfg.checkpoint_interval > 0 and round_num % cfg.checkpoint_interval == 0: | ||
| logging.log(logging.INFO, f"Saving checkpoint at round {round_num}") | ||
| save(f"{round_num:05d}") | ||
|
|
||
| trainer.train(cfg.total_timesteps, callback) | ||
| imit_stats = policy_evaluation.eval_policy(trainer.policy, cfg.evaluation) | ||
|
|
||
| # Save final artifacts. | ||
| if cfg.checkpoint_interval >= 0: | ||
| logging.log(logging.INFO, "Saving final checkpoint.") | ||
| save("final") | ||
|
|
||
| return { | ||
| "imit_stats": imit_stats, | ||
| "expert_stats": rollout.rollout_stats( | ||
| cast(Sequence[TrajectoryWithRew], trainer.get_demonstrations()), | ||
| ), | ||
| } | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_airl() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Structured Hydra configuration for Imitation algorithms.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| """Config for AIRL.""" | ||
| import dataclasses | ||
| from typing import Optional | ||
|
|
||
| from omegaconf import MISSING | ||
|
|
||
| from imitation_cli.utils import environment as environment_cfg | ||
| from imitation_cli.utils import ( | ||
| optimizer_class, | ||
| reward_network, | ||
| rl_algorithm, | ||
| trajectories, | ||
| ) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Config: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this is a bit confusing—won't users of this class always have to import it as
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have the convention to not directly import classes, so in a different file we would write: from imitation_cli.algorithm_configurations import airl
...
my_conf = airl.Config() |
||
| """Config for AIRL.""" | ||
|
|
||
| _target_: str = "imitation.algorithms.adversarial.airl.AIRL" | ||
| venv: environment_cfg.Config = MISSING | ||
| demonstrations: trajectories.Config = MISSING | ||
| gen_algo: rl_algorithm.Config = MISSING | ||
| reward_net: reward_network.Config = MISSING | ||
| demo_batch_size: int = 64 | ||
| n_disc_updates_per_round: int = 2 | ||
| disc_opt_cls: optimizer_class.Config = optimizer_class.Adam | ||
| gen_train_timesteps: Optional[int] = None | ||
| gen_replay_buffer_capacity: Optional[int] = None | ||
| init_tensorboard: bool = False | ||
| init_tensorboard_graph: bool = False | ||
| debug_use_ground_truth: bool = False | ||
| allow_variable_horizon: bool = False | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| defaults: | ||
| - airl_run_base | ||
| - environment: gym_env | ||
| - airl/reward_net: shaped | ||
| - airl/gen_algo: ppo | ||
| - evaluation: default_evaluation | ||
| - airl/demonstrations: generated | ||
| - airl/demonstrations/expert_policy: random | ||
| - override hydra/sweeper: optuna | ||
| - _self_ | ||
|
|
||
| total_timesteps: 40000 | ||
| checkpoint_interval: 1 | ||
|
|
||
| airl: | ||
| demo_batch_size: 128 | ||
| demonstrations: | ||
| total_timesteps: 10 | ||
| allow_variable_horizon: true | ||
|
|
||
| hydra: | ||
| mode: MULTIRUN | ||
| sweeper: | ||
| params: | ||
| environment: cartpole,pendulum | ||
| airl/reward_net: basic,shaped,small_ensemble |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| defaults: | ||
| - airl_run_base | ||
| - environment: cartpole | ||
| - airl/reward_net: shaped | ||
| - airl/gen_algo: ppo | ||
| - evaluation: default_evaluation | ||
| - airl/demonstrations: generated | ||
| - airl/demonstrations/expert_policy: random | ||
| # - [email protected]_net.environment: pendulum # This is how we inject a different environment | ||
| - _self_ | ||
|
|
||
| total_timesteps: 40000 | ||
| checkpoint_interval: 1 | ||
|
|
||
| airl: | ||
| demo_batch_size: 128 | ||
| demonstrations: | ||
| total_timesteps: 10 | ||
| allow_variable_horizon: true |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| defaults: | ||
| - airl_run_base | ||
| - environment: gym_env | ||
| - airl/reward_net: shaped | ||
| - airl/gen_algo: ppo | ||
| - evaluation: default_evaluation | ||
| - airl/demonstrations: generated | ||
| - airl/demonstrations/expert_policy: random | ||
| - _self_ | ||
|
|
||
| total_timesteps: 40000 | ||
| checkpoint_interval: 1 | ||
|
|
||
| airl: | ||
| demo_batch_size: 128 | ||
| demonstrations: | ||
| total_timesteps: 10 | ||
| allow_variable_horizon: true | ||
|
|
||
| hydra: | ||
| mode: MULTIRUN | ||
| sweeper: | ||
| params: | ||
| environment: cartpole,pendulum | ||
| airl/reward_net: basic,shaped,small_ensemble |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Configurations to be used as ingredient to algorithm configurations.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| """Classes for configuring activation functions.""" | ||
| import dataclasses | ||
| from enum import Enum | ||
|
|
||
| import torch | ||
| from hydra.core.config_store import ConfigStore | ||
|
|
||
|
|
||
| class ActivationFunctionClass(Enum): | ||
| """Enum of activation function classes.""" | ||
|
|
||
| TanH = torch.nn.Tanh | ||
| ReLU = torch.nn.ReLU | ||
| LeakyReLU = torch.nn.LeakyReLU | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Config: | ||
| """Base class for activation function configs.""" | ||
|
|
||
| activation_function_class: ActivationFunctionClass | ||
| _target_: str = "imitation_cli.utils.activation_function_class.Config.make" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way to simplify this sort of thing so that it doesn't have to be repeated in every class?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There probably is a super smart™ solution that is entirely incomprehensible to the casual user. This is one of the cases where I decided to accept some level of repetition for the sake of simpler understanding. |
||
|
|
||
| @staticmethod | ||
| def make(activation_function_class: ActivationFunctionClass) -> type: | ||
| return activation_function_class.value | ||
|
|
||
|
|
||
| TanH = Config(ActivationFunctionClass.TanH) | ||
| ReLU = Config(ActivationFunctionClass.ReLU) | ||
| LeakyReLU = Config(ActivationFunctionClass.LeakyReLU) | ||
|
|
||
|
|
||
| def register_configs(group: str): | ||
| cs = ConfigStore.instance() | ||
| for cls in ActivationFunctionClass: | ||
| cs.store(group=group, name=cls.name.lower(), node=Config(cls)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Programmatically generating names like this will make it fractionally harder to grep for where things are defined (e.g. if I do a case sensitive search for "tanh" then this file won't come up).
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is a really good point! In general configuring classes seems to require a lot of boilerplate code the way I do it here so I will try to find a smarter solution. Until I found one, what do you think about just removing |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| """Configuration for Gym environments.""" | ||
| from __future__ import annotations | ||
|
|
||
| import dataclasses | ||
| import typing | ||
| from typing import Optional, Union, cast | ||
|
|
||
| if typing.TYPE_CHECKING: | ||
| from stable_baselines3.common.vec_env import VecEnv | ||
|
|
||
| from hydra.core.config_store import ConfigStore | ||
| from hydra.utils import instantiate | ||
| from omegaconf import MISSING | ||
|
|
||
| from imitation_cli.utils import randomness | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Config: | ||
| """Configuration for Gym environments.""" | ||
|
|
||
| _target_: str = "imitation_cli.utils.environment.Config.make" | ||
| env_name: str = MISSING # The environment to train on | ||
| n_envs: int = 8 # number of environments in VecEnv | ||
| # TODO: when setting this to true this is really slow for some reason | ||
| parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv | ||
| max_episode_steps: int = MISSING # Set to positive int to limit episode horizons | ||
| env_make_kwargs: dict = dataclasses.field( | ||
| default_factory=dict, | ||
| ) # The kwargs passed to `spec.make`. | ||
| rng: randomness.Config = MISSING | ||
|
|
||
| @staticmethod | ||
| def make(log_dir: Optional[str] = None, **kwargs) -> VecEnv: | ||
| from imitation.util import util | ||
|
|
||
| return util.make_vec_env(log_dir=log_dir, **kwargs) | ||
|
|
||
|
|
||
| def make_rollout_venv(environment_config: Config) -> VecEnv: | ||
| from imitation.data import wrappers | ||
|
|
||
| return instantiate( | ||
| environment_config, | ||
| log_dir=None, | ||
| post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], | ||
| ) | ||
|
|
||
|
|
||
| def register_configs( | ||
| group: str, | ||
| default_rng: Union[randomness.Config, str] = MISSING, | ||
| ): | ||
| default_rng = cast(randomness.Config, default_rng) | ||
| cs = ConfigStore.instance() | ||
| cs.store(group=group, name="gym_env", node=Config(rng=default_rng)) | ||
| cs.store( | ||
| group=group, | ||
| name="cartpole", | ||
| node=Config(env_name="CartPole-v0", max_episode_steps=500, rng=default_rng), | ||
| ) | ||
| cs.store( | ||
| group=group, | ||
| name="pendulum", | ||
| node=Config(env_name="Pendulum-v1", max_episode_steps=500, rng=default_rng), | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| """Register Hydra configs for stable_baselines3 feature extractors.""" | ||
| import dataclasses | ||
| from enum import Enum | ||
|
|
||
| import stable_baselines3.common.torch_layers as torch_layers | ||
| from hydra.core.config_store import ConfigStore | ||
|
|
||
|
|
||
| class FeatureExtractorClass(Enum): | ||
| """Enum of feature extractor classes.""" | ||
|
|
||
| FlattenExtractor = torch_layers.FlattenExtractor | ||
| NatureCNN = torch_layers.NatureCNN | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Config: | ||
| """Base config for stable_baselines3 feature extractors.""" | ||
|
|
||
| feature_extractor_class: FeatureExtractorClass | ||
| _target_: str = "imitation_cli.utils.feature_extractor_class.Config.make" | ||
|
|
||
| @staticmethod | ||
| def make(feature_extractor_class: FeatureExtractorClass) -> type: | ||
| return feature_extractor_class.value | ||
|
|
||
|
|
||
| FlattenExtractor = Config(FeatureExtractorClass.FlattenExtractor) | ||
| NatureCNN = Config(FeatureExtractorClass.NatureCNN) | ||
|
|
||
|
|
||
| def register_configs(group: str): | ||
| cs = ConfigStore.instance() | ||
| for cls in FeatureExtractorClass: | ||
| cs.store(group=group, name=cls.name.lower(), node=Config(cls)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add explicit dep on omegaconf (which is used directly in subsequent files).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not add it because Hydra is heavily based on Omegaconf (it grew out of that project, see more details here), but it can't hurt to add it explicitly just in case Hydra decides to change the backend.