Commits (25)
993e059
First draft with somewhat working AIRL configured using Hydra.
ernestum Apr 17, 2023
1328c2a
Use recursive calls and introduce random number generator config.
ernestum Apr 17, 2023
a72a25d
Split up AIRL config into AIRL creation and AIRL run config. Also move …
ernestum Apr 18, 2023
89214d9
Use inline imports to speed up shell completion.
ernestum Apr 19, 2023
edd12c4
Add type annotations, fix typing issues, add comments.
ernestum Apr 19, 2023
c014dfb
Remove env_test.py
ernestum Apr 19, 2023
6c2b3a1
Add code for policy evaluation.
ernestum Apr 20, 2023
fe0bc1d
Move airl configuration to its own file and restructure the main run…
ernestum Apr 22, 2023
2906d55
Add checkpoint saving.
ernestum Apr 22, 2023
f1366e6
Use the hydra chdir feature to store logs in the output directory.
ernestum Apr 22, 2023
cb92c33
Remove defaults parameter from `register_configs`, introduce an air_ru…
ernestum Apr 24, 2023
b8221ab
Define cartpole and pendulum envs as structured configs.
ernestum Apr 24, 2023
0912138
Ensure PPO on disk does not inherit from PPO and loads from an absolu…
ernestum Apr 24, 2023
a0c0d98
Remove low default number of steps for generated trajectories and mov…
ernestum Apr 24, 2023
1ff9aa1
Introduce default_environment to the register_configs functions, remo…
ernestum Apr 24, 2023
e37fd4c
Move registering the expert policy as a sub-call to registering the t…
ernestum Apr 24, 2023
59c2b08
Update the airl_sweep.yaml
ernestum Apr 24, 2023
244eed0
Update the airl_sweep.yaml
ernestum Apr 24, 2023
ff0028f
Add airl_optuna.yaml
ernestum Apr 24, 2023
ca632a8
Formatting, typing and documentation fixes. Also the implicit seed de…
ernestum Apr 25, 2023
1738822
Switch from Hydra call to hydra instantiate.
ernestum Apr 26, 2023
5638222
Add type ignore reason.
ernestum Apr 26, 2023
efe3b90
Don't allow variable horizon for AIRL by default.
ernestum Apr 27, 2023
5915a80
Simplify the class configurations using enums.
ernestum Apr 27, 2023
6760104
Fix bug in type ignore reason.
ernestum Apr 27, 2023
1 change: 1 addition & 0 deletions setup.py
@@ -209,6 +209,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"tensorboard>=1.14",
"huggingface_sb3>=2.2.1",
"datasets>=2.8.0",
"hydra-core>=1.3.2",
Member:
Add explicit dep on omegaconf (which is used directly in subsequent files).

Collaborator Author:
I did not add it because Hydra is heavily based on Omegaconf (it grew out of that project, see more details here), but it can't hurt to add it explicitly just in case Hydra decides to change the backend.
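
For concreteness, the explicit dependency would be a one-line addition next to the hydra pin; a sketch (the version bound below is an assumption, not taken from this PR):

    "hydra-core>=1.3.2",
    "omegaconf>=2.3.0",  # used directly in the CLI modules; explicit in case Hydra changes backends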

],
tests_require=TESTS_REQUIRE,
extras_require={
5 changes: 5 additions & 0 deletions src/imitation/algorithms/adversarial/common.py
@@ -187,6 +187,7 @@ def __init__(
if self.demo_batch_size % self.demo_minibatch_size != 0:
raise ValueError("Batch size must be a multiple of minibatch size.")
self._demo_data_loader = None
self._demonstrations: Optional[base.AnyTransitions] = None
self._endless_expert_iterator = None
super().__init__(
demonstrations=demonstrations,
@@ -298,12 +299,16 @@ def reward_test(self) -> reward_nets.RewardNet:
"""Reward used to train policy at "test" time after adversarial training."""

def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
self._demonstrations = demonstrations
self._demo_data_loader = base.make_data_loader(
demonstrations,
self.demo_batch_size,
)
self._endless_expert_iterator = util.endless_iter(self._demo_data_loader)

def get_demonstrations(self) -> Optional[base.AnyTransitions]:
return self._demonstrations

def _next_expert_batch(self) -> Mapping:
assert self._endless_expert_iterator is not None
return next(self._endless_expert_iterator)
1 change: 1 addition & 0 deletions src/imitation_cli/__init__.py
@@ -0,0 +1 @@
"""Hydra configurations and scripts that form a CLI for imitation."""
105 changes: 105 additions & 0 deletions src/imitation_cli/airl.py
@@ -0,0 +1,105 @@
"""Config and run configuration for AIRL."""
import dataclasses
import logging
import pathlib
from typing import Any, Dict, Sequence, cast

import hydra
import torch as th
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from omegaconf import MISSING

from imitation.policies import serialize
from imitation_cli.algorithm_configurations import airl as airl_cfg
from imitation_cli.utils import environment as environment_cfg
from imitation_cli.utils import (
policy_evaluation,
randomness,
reward_network,
rl_algorithm,
trajectories,
)


@dataclasses.dataclass
class RunConfig:
"""Config for running AIRL."""

rng: randomness.Config = randomness.Config(seed=0)
total_timesteps: int = int(1e6)
checkpoint_interval: int = 0

environment: environment_cfg.Config = MISSING
airl: airl_cfg.Config = MISSING
evaluation: policy_evaluation.Config = MISSING
# This ensures that the working directory is changed
# to the hydra output dir
hydra: Any = dataclasses.field(default_factory=lambda: dict(job=dict(chdir=True)))


cs = ConfigStore.instance()
environment_cfg.register_configs("environment", "${rng}")
trajectories.register_configs("airl/demonstrations", "${environment}", "${rng}")
rl_algorithm.register_configs("airl/gen_algo", "${environment}", "${rng.seed}")
reward_network.register_configs("airl/reward_net", "${environment}")
policy_evaluation.register_configs("evaluation", "${environment}", "${rng}")

cs.store(
name="airl_run_base",
node=RunConfig(
airl=airl_cfg.Config(
venv="${environment}", # type: ignore[arg-type]
),
),
)


@hydra.main(
version_base=None,
config_path="config",
config_name="airl_run",
)
def run_airl(cfg: RunConfig) -> Dict[str, Any]:
from imitation.algorithms.adversarial import airl
from imitation.data import rollout
from imitation.data.types import TrajectoryWithRew

trainer: airl.AIRL = instantiate(cfg.airl)

checkpoints_path = pathlib.Path("checkpoints")

def save(path: str):
"""Save discriminator and generator."""
# We implement this here and not in Trainer since we do not want to actually
# serialize the whole Trainer (including e.g. expert demonstrations).
save_path = checkpoints_path / path
save_path.mkdir(parents=True, exist_ok=True)

th.save(trainer.reward_train, save_path / "reward_train.pt")
th.save(trainer.reward_test, save_path / "reward_test.pt")
serialize.save_stable_model(save_path / "gen_policy", trainer.gen_algo)

def callback(round_num: int, /) -> None:
if cfg.checkpoint_interval > 0 and round_num % cfg.checkpoint_interval == 0:
logging.log(logging.INFO, f"Saving checkpoint at round {round_num}")
save(f"{round_num:05d}")

trainer.train(cfg.total_timesteps, callback)
imit_stats = policy_evaluation.eval_policy(trainer.policy, cfg.evaluation)

# Save final artifacts.
if cfg.checkpoint_interval >= 0:
logging.log(logging.INFO, "Saving final checkpoint.")
save("final")

return {
"imit_stats": imit_stats,
"expert_stats": rollout.rollout_stats(
cast(Sequence[TrajectoryWithRew], trainer.get_demonstrations()),
),
}


if __name__ == "__main__":
run_airl()
1 change: 1 addition & 0 deletions src/imitation_cli/algorithm_configurations/__init__.py
@@ -0,0 +1 @@
"""Structured Hydra configuration for Imitation algorithms."""
33 changes: 33 additions & 0 deletions src/imitation_cli/algorithm_configurations/airl.py
@@ -0,0 +1,33 @@
"""Config for AIRL."""
import dataclasses
from typing import Optional

from omegaconf import MISSING

from imitation_cli.utils import environment as environment_cfg
from imitation_cli.utils import (
optimizer_class,
reward_network,
rl_algorithm,
trajectories,
)


@dataclasses.dataclass
class Config:
Member:
nit: this is a bit confusing—won't users of this class always have to import it as AIRLConfig or something that disambiguates between the other configs?

Collaborator Author:
We have a convention of not importing classes directly, so in a different file we would write:

from imitation_cli.algorithm_configurations import airl
...
my_conf = airl.Config()

"""Config for AIRL."""

_target_: str = "imitation.algorithms.adversarial.airl.AIRL"
venv: environment_cfg.Config = MISSING
demonstrations: trajectories.Config = MISSING
gen_algo: rl_algorithm.Config = MISSING
reward_net: reward_network.Config = MISSING
demo_batch_size: int = 64
n_disc_updates_per_round: int = 2
disc_opt_cls: optimizer_class.Config = optimizer_class.Adam
gen_train_timesteps: Optional[int] = None
gen_replay_buffer_capacity: Optional[int] = None
init_tensorboard: bool = False
init_tensorboard_graph: bool = False
debug_use_ground_truth: bool = False
allow_variable_horizon: bool = False
26 changes: 26 additions & 0 deletions src/imitation_cli/config/airl_optuna.yaml
@@ -0,0 +1,26 @@
defaults:
- airl_run_base
- environment: gym_env
- airl/reward_net: shaped
- airl/gen_algo: ppo
- evaluation: default_evaluation
- airl/demonstrations: generated
- airl/demonstrations/expert_policy: random
- override hydra/sweeper: optuna
- _self_

total_timesteps: 40000
checkpoint_interval: 1

airl:
demo_batch_size: 128
demonstrations:
total_timesteps: 10
allow_variable_horizon: true

hydra:
mode: MULTIRUN
sweeper:
params:
environment: cartpole,pendulum
airl/reward_net: basic,shaped,small_ensemble
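
Assuming the hydra-optuna-sweeper plugin is installed, a sweep with this config could be launched along these lines (illustrative invocation; the MULTIRUN mode above makes a --multirun flag unnecessary):

python -m imitation_cli.airl --config-name airl_optuna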
19 changes: 19 additions & 0 deletions src/imitation_cli/config/airl_run.yaml
@@ -0,0 +1,19 @@
defaults:
- airl_run_base
- environment: cartpole
- airl/reward_net: shaped
- airl/gen_algo: ppo
- evaluation: default_evaluation
- airl/demonstrations: generated
- airl/demonstrations/expert_policy: random
# - [email protected]_net.environment: pendulum # This is how we inject a different environment
- _self_

total_timesteps: 40000
checkpoint_interval: 1

airl:
demo_batch_size: 128
demonstrations:
total_timesteps: 10
allow_variable_horizon: true
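
For reference, a run with this config might be launched as below, with any field overridden via Hydra's dotted syntax (illustrative invocation, assuming the package is importable):

python -m imitation_cli.airl environment=pendulum airl.demo_batch_size=256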
25 changes: 25 additions & 0 deletions src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml
@@ -0,0 +1,25 @@
defaults:
- airl_run_base
- environment: gym_env
- airl/reward_net: shaped
- airl/gen_algo: ppo
- evaluation: default_evaluation
- airl/demonstrations: generated
- airl/demonstrations/expert_policy: random
- _self_

total_timesteps: 40000
checkpoint_interval: 1

airl:
demo_batch_size: 128
demonstrations:
total_timesteps: 10
allow_variable_horizon: true

hydra:
mode: MULTIRUN
sweeper:
params:
environment: cartpole,pendulum
airl/reward_net: basic,shaped,small_ensemble
1 change: 1 addition & 0 deletions src/imitation_cli/utils/__init__.py
@@ -0,0 +1 @@
"""Configurations to be used as ingredient to algorithm configurations."""
37 changes: 37 additions & 0 deletions src/imitation_cli/utils/activation_function_class.py
@@ -0,0 +1,37 @@
"""Classes for configuring activation functions."""
import dataclasses
from enum import Enum

import torch
from hydra.core.config_store import ConfigStore


class ActivationFunctionClass(Enum):
"""Enum of activation function classes."""

TanH = torch.nn.Tanh
ReLU = torch.nn.ReLU
LeakyReLU = torch.nn.LeakyReLU


@dataclasses.dataclass
class Config:
"""Base class for activation function configs."""

activation_function_class: ActivationFunctionClass
_target_: str = "imitation_cli.utils.activation_function_class.Config.make"
Member:
Is there a way to simplify this sort of thing so that it doesn't have to be repeated in every class?
(e.g. inheriting from a base class with an appropriate target property)

Collaborator Author:
There probably is a super smart™ solution that is entirely incomprehensible to the casual user. This is one of the cases where I decided to accept some level of repetition for the sake of simpler understanding.
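
For what it's worth, one middle-ground sketch would be a shared base that derives `_target_` from the subclass's own module path (untested, names hypothetical; subclass fields would then need defaults such as MISSING, since the base already defines a defaulted field):

import dataclasses


@dataclasses.dataclass
class ClassConfigBase:
    """Shared base for configs that wrap a class behind a make() method."""

    _target_: str = ""

    def __post_init__(self) -> None:
        # Fill in "package.module.Config.make" automatically instead of
        # repeating the string literal in every config class.
        if not self._target_:
            cls = type(self)
            self._target_ = f"{cls.__module__}.{cls.__qualname__}.make"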


@staticmethod
def make(activation_function_class: ActivationFunctionClass) -> type:
return activation_function_class.value


TanH = Config(ActivationFunctionClass.TanH)
ReLU = Config(ActivationFunctionClass.ReLU)
LeakyReLU = Config(ActivationFunctionClass.LeakyReLU)


def register_configs(group: str):
cs = ConfigStore.instance()
for cls in ActivationFunctionClass:
cs.store(group=group, name=cls.name.lower(), node=Config(cls))
Member:
Programmatically generating names like this will make it fractionally harder to grep for where things are defined (e.g. if I do a case sensitive search for "tanh" then this file won't come up).

Collaborator Author:
That is a really good point! In general, configuring classes the way I do it here seems to require a lot of boilerplate, so I will try to find a smarter solution. Until I find one, what do you think about just removing .lower()?
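
Concretely, dropping .lower() would keep the enum member's exact spelling as the config name (sketch):

def register_configs(group: str):
    cs = ConfigStore.instance()
    for cls in ActivationFunctionClass:
        # Stored under e.g. "TanH", so a case-sensitive grep for the config
        # name also hits the enum definition in this file.
        cs.store(group=group, name=cls.name, node=Config(cls))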

66 changes: 66 additions & 0 deletions src/imitation_cli/utils/environment.py
@@ -0,0 +1,66 @@
"""Configuration for Gym environments."""
from __future__ import annotations

import dataclasses
import typing
from typing import Optional, Union, cast

if typing.TYPE_CHECKING:
from stable_baselines3.common.vec_env import VecEnv

from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from omegaconf import MISSING

from imitation_cli.utils import randomness


@dataclasses.dataclass
class Config:
"""Configuration for Gym environments."""

_target_: str = "imitation_cli.utils.environment.Config.make"
env_name: str = MISSING # The environment to train on
n_envs: int = 8 # number of environments in VecEnv
# TODO: setting this to true is really slow for some reason
parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv
max_episode_steps: int = MISSING # Set to positive int to limit episode horizons
env_make_kwargs: dict = dataclasses.field(
default_factory=dict,
) # The kwargs passed to `spec.make`.
rng: randomness.Config = MISSING

@staticmethod
def make(log_dir: Optional[str] = None, **kwargs) -> VecEnv:
from imitation.util import util

return util.make_vec_env(log_dir=log_dir, **kwargs)


def make_rollout_venv(environment_config: Config) -> VecEnv:
from imitation.data import wrappers

return instantiate(
environment_config,
log_dir=None,
post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)],
)


def register_configs(
group: str,
default_rng: Union[randomness.Config, str] = MISSING,
):
default_rng = cast(randomness.Config, default_rng)
cs = ConfigStore.instance()
cs.store(group=group, name="gym_env", node=Config(rng=default_rng))
cs.store(
group=group,
name="cartpole",
node=Config(env_name="CartPole-v0", max_episode_steps=500, rng=default_rng),
)
cs.store(
group=group,
name="pendulum",
node=Config(env_name="Pendulum-v1", max_episode_steps=500, rng=default_rng),
)
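
With these store calls in place, an environment is chosen from the defaults list or the command line (e.g. environment=cartpole, as in airl_run.yaml above); the generic gym_env entry leaves env_name and max_episode_steps to be supplied as overrides, e.g. environment=gym_env environment.env_name=MountainCar-v0 (override values illustrative).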
35 changes: 35 additions & 0 deletions src/imitation_cli/utils/feature_extractor_class.py
@@ -0,0 +1,35 @@
"""Register Hydra configs for stable_baselines3 feature extractors."""
import dataclasses
from enum import Enum

import stable_baselines3.common.torch_layers as torch_layers
from hydra.core.config_store import ConfigStore


class FeatureExtractorClass(Enum):
"""Enum of feature extractor classes."""

FlattenExtractor = torch_layers.FlattenExtractor
NatureCNN = torch_layers.NatureCNN


@dataclasses.dataclass
class Config:
"""Base config for stable_baselines3 feature extractors."""

feature_extractor_class: FeatureExtractorClass
_target_: str = "imitation_cli.utils.feature_extractor_class.Config.make"

@staticmethod
def make(feature_extractor_class: FeatureExtractorClass) -> type:
return feature_extractor_class.value


FlattenExtractor = Config(FeatureExtractorClass.FlattenExtractor)
NatureCNN = Config(FeatureExtractorClass.NatureCNN)


def register_configs(group: str):
cs = ConfigStore.instance()
for cls in FeatureExtractorClass:
cs.store(group=group, name=cls.name.lower(), node=Config(cls))
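
As a sanity check of the pattern shared by these modules, instantiating one of the predefined configs should hand back the raw class; an illustrative snippet, run outside any Hydra app:

import stable_baselines3.common.torch_layers as torch_layers
from hydra.utils import instantiate

from imitation_cli.utils import feature_extractor_class as fec

# instantiate() dispatches to Config.make, which unwraps the enum member.
extractor_cls = instantiate(fec.NatureCNN)
assert extractor_cls is torch_layers.NatureCNN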