DMC training and tutorial benchmark improvements (#6)
## DMC training
1. An example for training on environments and tasks from DMC.
2. A README for examples, which now explains the DMC training example.

## Improvements in the benchmarking examples and tutorials
1. Improvements in `tutorial_benchmark.ipynb`
2. Removal of `examples/humenv_evaluation.py`, as it is redundant with `tutorial_benchmark.ipynb`
1 parent ff0d9a1 · commit b947ca6
Showing 4 changed files with 429 additions and 195 deletions.
@@ -0,0 +1,33 @@
# Examples

We provide a few examples of how to use the Meta Motivo repository.

## Offline training with ExoRL datasets

[ExoRL](https://github.com/denisyarats/exorl) has been widely used to train offline RL algorithms. We provide code for training FB on standard domains such as `walker`, `cheetah`, `quadruped` and `pointmass`. We use the standard tasks in `dm_control`, but you can easily update the script to run the full set of tasks defined in ExoRL or in the paper [Fast Imitation via Behavior Foundation Models](https://openreview.net/forum?id=qnWtw3l0jb). We provide more details below.
To use the provided script, you can simply run from a terminal

```bash
python fb_train_dmc.py --domain_name walker --dataset_root <root to exorl data>
```
The standard folder structure of ExoRL is `<prefix>/datasets/${DOMAIN}/${ALGO}/buffer`, so we expect `dataset_root=<prefix>/datasets`. Since the original creation of ExoRL, MuJoCo has seen many updates. To re-run all the actions and collect physics-consistent data, you may optionally replay the trajectories; we refer to [https://github.com/facebookresearch/mtm/tree/main/research/exorl](https://github.com/facebookresearch/mtm/tree/main/research/exorl) for this.
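For reference, here is a minimal sketch (the dataset path, domain and `rnd` exploration agent below are placeholder choices) of how the training script resolves this layout when loading episodes:

```python
from pathlib import Path

# Placeholder values: point dataset_root at <prefix>/datasets and pick the
# domain / exploration algorithm you actually downloaded.
dataset_root = Path("/path/to/exorl/datasets")
domain, algo = "walker", "rnd"

# fb_train_dmc.py looks for .npz episode files under <dataset_root>/<domain>/<algo>/buffer
buffer_dir = dataset_root / domain / algo / "buffer"
episodes = sorted(buffer_dir.glob("*.npz"))
print(f"{buffer_dir}: found {len(episodes)} episode files")
```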
If you want to run auxiliary tasks and domains such as `walker_flip` or `pointmass`, we suggest downloading the files from [https://github.com/facebookresearch/offline_rl/tree/main/src/dmc_tasks](https://github.com/facebookresearch/offline_rl/tree/main/src/dmc_tasks) into `examples/dmc_tasks`. You can then modify `fb_train_dmc.py` as follows:
- add the import
  ```python
  from dmc_tasks import dmc
  ```
- add the new tasks
  ```python
  ALL_TASKS = {
      "walker": ["walk", "run", "stand", "flip", "spin"],
      "cheetah": ["walk", "run", "walk_backward", "run_backward"],
      "pointmass": ["reach_top_left", "reach_top_right", "reach_bottom_right", "reach_bottom_left", "loop", "square", "fast_slow"],
      "quadruped": ["jump", "walk", "run", "stand"],
  }
  ```
- use `dmc.make` for environment creation. For example, replace `suite.load(domain_name=self.cfg.domain_name, task_name=task, environment_kwargs={"flat_observation": True})` with `dmc.make(f"{self.cfg.domain_name}_{task}")` (see the sketch after this list).
- This changes how the observation is retrieved: use `time_step.observation` instead of `time_step.observation["observations"]`. Update the file accordingly.
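As a rough sketch of the two changes above (assuming the `dmc_tasks` files have been copied into `examples/dmc_tasks` as described), environment creation and observation access would change roughly as follows:

```python
from dm_control import suite
from dmc_tasks import dmc  # assumes examples/dmc_tasks has been populated as described above

# Before: standard dm_control suite loading with a flattened observation dict
env = suite.load(domain_name="walker", task_name="walk", environment_kwargs={"flat_observation": True})
time_step = env.reset()
obs = time_step.observation["observations"]

# After: dmc.make takes a single "<domain>_<task>" name and already returns flat observations
env = dmc.make("walker_flip")
time_step = env.reset()
obs = time_step.observation
```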
@@ -0,0 +1,346 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the CC BY-NC 4.0 license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations
import torch

torch.set_float32_matmul_precision("high")

import numpy as np
import dataclasses
from metamotivo.buffers.buffers import DictBuffer
from metamotivo.fb import FBAgent, FBAgentConfig
from metamotivo.nn_models import eval_mode
from tqdm import tqdm
import time
from dm_control import suite
import random
from pathlib import Path
import wandb
import json
from typing import List
import mujoco
import warnings
import tyro

ALL_TASKS = {
    "walker": [
        "walk",
        "run",
        "stand",
    ],
    "cheetah": ["walk", "run"],
    "quadruped": ["walk", "run"],
}


def create_agent(
    domain_name="walker",
    task_name="walk",
    device="cpu",
    compile=False,
    cudagraphs=False,
) -> FBAgentConfig:
    if domain_name not in ["walker", "pointmass", "cheetah", "quadruped"]:
        raise RuntimeError('FB configuration defined only for "walker", "pointmass", "cheetah", "quadruped"')
    env = suite.load(
        domain_name=domain_name,
        task_name=task_name,
        environment_kwargs={"flat_observation": True},
    )

    agent_config = FBAgentConfig()
    agent_config.model.obs_dim = env.observation_spec()["observations"].shape[0]
    agent_config.model.action_dim = env.action_spec().shape[0]
    agent_config.model.device = device
    agent_config.model.norm_obs = False
    agent_config.model.seq_length = 1
    agent_config.train.batch_size = 1024
    # archi
    if domain_name in ["walker", "pointmass"]:
        agent_config.model.archi.z_dim = 100
    else:
        agent_config.model.archi.z_dim = 50
    agent_config.model.archi.b.norm = True
    agent_config.model.archi.norm_z = True
    agent_config.model.archi.b.hidden_dim = 256
    agent_config.model.archi.f.hidden_dim = 1024
    agent_config.model.archi.actor.hidden_dim = 1024
    agent_config.model.archi.f.hidden_layers = 1
    agent_config.model.archi.actor.hidden_layers = 1
    agent_config.model.archi.b.hidden_layers = 2
    # optim
    if domain_name == "pointmass":
        agent_config.train.lr_f = 1e-4
        agent_config.train.lr_b = 1e-6
        agent_config.train.lr_actor = 1e-6
    else:
        agent_config.train.lr_f = 1e-4
        agent_config.train.lr_b = 1e-4
        agent_config.train.lr_actor = 1e-4
    agent_config.train.ortho_coef = 1
    agent_config.train.train_goal_ratio = 0.5
    agent_config.train.fb_pessimism_penalty = 0
    agent_config.train.actor_pessimism_penalty = 0.5

    if domain_name == "pointmass":
        agent_config.train.discount = 0.99
    else:
        agent_config.train.discount = 0.98
    agent_config.compile = compile
    agent_config.cudagraphs = cudagraphs

    return agent_config

def load_data(dataset_path, expl_agent, domain_name, num_episodes=1):
    # ExoRL episodes are stored as .npz files under <dataset_path>/<domain>/<expl_agent>/buffer
    path = Path(dataset_path) / f"{domain_name}/{expl_agent}/buffer"
    print(f"Data path: {path}")
    storage = {
        "observation": [],
        "action": [],
        "physics": [],
        "next": {"observation": [], "terminated": [], "physics": []},
    }
    files = list(path.glob("*.npz"))
    num_episodes = min(num_episodes, len(files))
    for i in tqdm(range(num_episodes)):
        f = files[i]
        data = np.load(str(f))
        # shift by one step to align (observation, action, next observation) transitions
        storage["observation"].append(data["observation"][:-1].astype(np.float32))
        storage["action"].append(data["action"][1:].astype(np.float32))
        storage["next"]["observation"].append(data["observation"][1:].astype(np.float32))
        storage["next"]["terminated"].append(np.array(1 - data["discount"][1:], dtype=bool))
        storage["physics"].append(data["physics"][:-1])
        storage["next"]["physics"].append(data["physics"][1:])

    for k in storage:
        if k == "next":
            for k1 in storage[k]:
                storage[k][k1] = np.concatenate(storage[k][k1])
        else:
            storage[k] = np.concatenate(storage[k])
    return storage


def set_seed_everywhere(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


@dataclasses.dataclass
class TrainConfig:
    dataset_root: str
    seed: int = 0
    domain_name: str = "walker"
    task_name: str | None = None
    dataset_expl_agent: str = "rnd"
    num_train_steps: int = 3_000_000
    load_n_episodes: int = 5_000
    log_every_updates: int = 10_000
    work_dir: str | None = None

    checkpoint_every_steps: int = 1_000_000

    # eval
    num_eval_episodes: int = 10
    num_inference_samples: int = 50_000
    eval_every_steps: int = 100_000
    eval_tasks: List[str] | None = None

    # misc
    compile: bool = False
    cudagraphs: bool = False
    device: str = "cuda"

    # WANDB
    use_wandb: bool = False
    wandb_ename: str | None = None
    wandb_gname: str | None = None
    wandb_pname: str | None = "fb_train_dmc"
    wandb_name_prefix: str | None = None

    def __post_init__(self):
        if self.eval_tasks is None:
            self.eval_tasks = ALL_TASKS[self.domain_name]

class Workspace:
    def __init__(self, cfg: TrainConfig, agent_cfg: FBAgentConfig) -> None:
        self.cfg = cfg
        self.agent_cfg = agent_cfg
        if self.cfg.work_dir is None:
            import string

            tmp_name = "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
            self.work_dir = Path.cwd() / "tmp_fbcpr" / tmp_name
            self.cfg.work_dir = str(self.work_dir)
        else:
            self.work_dir = Path(self.cfg.work_dir)
        self.work_dir = Path(self.work_dir)
        self.work_dir.mkdir(exist_ok=True, parents=True)
        print(f"working dir: {self.work_dir}")

        self.agent = FBAgent(**dataclasses.asdict(self.agent_cfg))
        set_seed_everywhere(self.cfg.seed)

        if self.cfg.use_wandb:
            exp_name = "fb"
            wandb_name = exp_name
            if self.cfg.wandb_name_prefix:
                wandb_name = f"{self.cfg.wandb_name_prefix}_{exp_name}"
            # fmt: off
            wandb_config = dataclasses.asdict(self.cfg)
            # fall back to the experiment name when no wandb group is provided
            wandb.init(entity=self.cfg.wandb_ename, project=self.cfg.wandb_pname,
                group=exp_name if self.cfg.wandb_gname is None else self.cfg.wandb_gname, name=wandb_name,  # mode="disabled",
                config=wandb_config)  # type: ignore
            # fmt: on

        with (self.work_dir / "config.json").open("w") as f:
            json.dump(dataclasses.asdict(self.cfg), f, indent=4)

    def train(self):
        self.start_time = time.time()
        self.train_offline()

    def train_offline(self) -> None:
        self.replay_buffer = {}
        # LOAD DATA FROM EXORL
        data = load_data(
            self.cfg.dataset_root,
            self.cfg.dataset_expl_agent,
            self.cfg.domain_name,
            self.cfg.load_n_episodes,
        )
        self.replay_buffer = {"train": DictBuffer(capacity=data["observation"].shape[0], device=self.agent.device)}
        self.replay_buffer["train"].extend(data)
        print(self.replay_buffer["train"])
        del data

        total_metrics = None
        fps_start_time = time.time()
        for t in tqdm(range(0, int(self.cfg.num_train_steps))):
            if t % self.cfg.eval_every_steps == 0:
                self.eval(t)

            # torch.compiler.cudagraph_mark_step_begin()
            metrics = self.agent.update(self.replay_buffer, t)

            # we need to copy tensors returned by a cudagraph module
            if total_metrics is None:
                total_metrics = {k: metrics[k].clone() for k in metrics.keys()}
            else:
                total_metrics = {k: total_metrics[k] + metrics[k] for k in metrics.keys()}

            if t % self.cfg.log_every_updates == 0:
                m_dict = {}
                for k in sorted(list(total_metrics.keys())):
                    tmp = total_metrics[k] / (1 if t == 0 else self.cfg.log_every_updates)
                    m_dict[k] = np.round(tmp.mean().item(), 6)
                m_dict["duration"] = time.time() - self.start_time
                m_dict["FPS"] = (1 if t == 0 else self.cfg.log_every_updates) / (time.time() - fps_start_time)
                if self.cfg.use_wandb:
                    wandb.log(
                        {f"train/{k}": v for k, v in m_dict.items()},
                        step=t,
                    )
                print(m_dict)
                total_metrics = None
                fps_start_time = time.time()
            if t % self.cfg.checkpoint_every_steps == 0:
                self.agent.save(str(self.work_dir / "checkpoint"))
        self.agent.save(str(self.work_dir / "checkpoint"))
        return

    def eval(self, t):
        for task in self.cfg.eval_tasks:
            z = self.reward_inference(task).reshape(1, -1)
            eval_env = suite.load(
                domain_name=self.cfg.domain_name,
                task_name=task,
                environment_kwargs={"flat_observation": True},
            )
            num_ep = self.cfg.num_eval_episodes
            total_reward = np.zeros((num_ep,), dtype=np.float64)
            for ep in range(num_ep):
                time_step = eval_env.reset()
                while not time_step.last():
                    with torch.no_grad(), eval_mode(self.agent._model):
                        obs = torch.tensor(
                            time_step.observation["observations"].reshape(1, -1),
                            device=self.agent.device,
                            dtype=torch.float32,
                        )
                        action = self.agent.act(obs=obs, z=z, mean=True).cpu().numpy()
                    time_step = eval_env.step(action)
                    total_reward[ep] += time_step.reward
            m_dict = {
                "reward": np.mean(total_reward),
                "reward#std": np.std(total_reward),
            }
            if self.cfg.use_wandb:
                wandb.log(
                    {f"{task}/{k}": v for k, v in m_dict.items()},
                    step=t,
                )
            m_dict["task"] = task
            print(m_dict)

    def reward_inference(self, task) -> torch.Tensor:
        env = suite.load(
            domain_name=self.cfg.domain_name,
            task_name=task,
            environment_kwargs={"flat_observation": True},
        )
        num_samples = self.cfg.num_inference_samples
        batch = self.replay_buffer["train"].sample(num_samples)
        rewards = []
        # relabel the sampled transitions with the task reward by restoring the saved physics state
        for i in range(num_samples):
            with env._physics.reset_context():
                env._physics.set_state(batch["next"]["physics"][i].cpu().numpy())
                env._physics.set_control(batch["action"][i].cpu().detach().numpy())
            mujoco.mj_forward(env._physics.model.ptr, env._physics.data.ptr)  # pylint: disable=no-member
            mujoco.mj_fwdPosition(env._physics.model.ptr, env._physics.data.ptr)  # pylint: disable=no-member
            mujoco.mj_sensorVel(env._physics.model.ptr, env._physics.data.ptr)  # pylint: disable=no-member
            mujoco.mj_subtreeVel(env._physics.model.ptr, env._physics.data.ptr)  # pylint: disable=no-member
            rewards.append(env._task.get_reward(env._physics))
        rewards = np.array(rewards).reshape(-1, 1)
        z = self.agent._model.reward_inference(
            next_obs=batch["next"]["observation"],
            reward=torch.tensor(rewards, dtype=torch.float32, device=self.agent.device),
        )
        return z


if __name__ == "__main__":
    config = tyro.cli(TrainConfig)

    warnings.warn(
        "Since the original creation of ExoRL, MuJoCo has seen many updates. To rerun all the actions and collect physics-consistent data, you may optionally use the update_data.py utility from MTM (https://github.com/facebookresearch/mtm/tree/main/research/exorl)."
    )
    if config.task_name is None:
        if config.domain_name == "walker":
            config.task_name = "walk"
        elif config.domain_name == "cheetah":
            config.task_name = "run"
        elif config.domain_name == "pointmass":
            config.task_name = "reach_top_left"
        elif config.domain_name == "quadruped":
            config.task_name = "run"
        else:
            raise RuntimeError("Unsupported domain, you need to specify task_name")
    agent_config = create_agent(
        domain_name=config.domain_name,
        task_name=config.task_name,
        device=config.device,
        compile=config.compile,
        cudagraphs=config.cudagraphs,
    )

    ws = Workspace(config, agent_cfg=agent_config)
    ws.train()