diff --git a/README.md b/README.md index 82643b9..c93cb91 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,8 @@ This repository provides a PyTorch implementation and pre-trained models for Met ### Features - We provide [**6** pretrained FB-CPR](https://huggingface.co/collections/facebook/meta-motivo-6757761e8fd4a032466fd129) models for controlling the humanoid model defined in [HumEnv](https://github.com/facebookresearch/HumEnv/). -- **Fully reproducible** scripts for evaluating the model in HumEnv -- **Training code for FB and FB-CPR algorithms** +- **Fully reproducible** scripts for evaluating the model in HumEnv. +- **Fully reproducible** [FB-CPR training code in HumEnv](examples/fbcpr_train_humenv.py) for the full results in the paper, and [FB training code in DMC](examples/fb_train_dmc.py) for faster experimentation. # Installation diff --git a/examples/README.md b/examples/README.md index c737b62..bab8ecd 100644 --- a/examples/README.md +++ b/examples/README.md @@ -2,13 +2,13 @@ We provide a few examples on how to use the Meta Motivo repository. -## Offline training with ExoRL datasets +## FB: Offline training with ExoRL datasets [ExoRL](https://github.com/denisyarats/exorl) has been widely used to train offline algorithms. We provide the code for training FB on standard domains such as `walker`, `cheetah`, `quadruped` and `pointmass`. We use the standard tasks in `dm_control`, but you can easily update the script to run the full set of tasks defined in `ExoRL` or in the paper [Fast Imitation via Behavior Foundation Models](https://openreview.net/forum?id=qnWtw3l0jb). We will provide more details below. To use the provided script you can simply run from terminal -```python +```bash python fb_train_dmc.py --domain_name walker --dataset_root ``` @@ -31,3 +31,32 @@ ALL_TASKS = { ``` - use `dmc.make` for environment creation. For example, replace `suite.load(domain_name=self.cfg.domain_name,task_name=task,environment_kwargs={"flat_observation": True},)` with `dmc.make(f"{self.cfg.domain_name}_{task}")`. - This changes the way of getting the observation from `time_step.observation["observations"]` to simply `time_step.observation`. Update the file accordingly. + + +## FB-CPR: Online training with HumEnv + +We provide a complete code for training FB-CPR as described in the paper [Zero-Shot Whole-Body Humanoid Control via Behavioral Foundation Models](https://ai.meta.com/research/publications/zero-shot-whole-body-humanoid-control-via-behavioral-foundation-models/). + +**IMPORTANT!** We assume you have already preprocessed the AMASS motions as described [here](https://github.com/facebookresearch/humenv/tree/main/data_preparation). In addition, we assume you also downloaded the `test_train_split` sub-folder. + +The script is setup with the S configuration (i.e., paper configuration) and can be run by simply calling + +```bash +python fbcpr_train_humenv.py --compile --motions test_train_split/large1_small1_train_0.1.txt --motions_root --prioritization +``` + +There are several parameters that can be changed to do evaluation more modular, checkpoint the models, etc. We refer to the code for more details. + +If you would like to train our largest model (the one deployed in the [demo](https://metamotivo.metademolab.com/)), replace the following line + +``` +model, hidden_dim, hidden_layers = "simple", 1024, 2 +``` + +with + +``` +model, hidden_dim, hidden_layers = "residual", 2048, 12 +``` + +NOTE: we recommend that you use compile=True on a A100 GPU or better, as otherwise training can be very slow. diff --git a/examples/fbcpr_train_humenv.py b/examples/fbcpr_train_humenv.py new file mode 100644 index 0000000..742d613 --- /dev/null +++ b/examples/fbcpr_train_humenv.py @@ -0,0 +1,479 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the CC BY-NC 4.0 license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import os + +os.environ["OMP_NUM_THREADS"] = "1" + +import torch + +torch.set_float32_matmul_precision("high") + +import collections +import dataclasses +import json +import numbers +import random +import time +from pathlib import Path +from typing import List + +import gymnasium +import humenv +import numpy as np +import tyro +from gymnasium.wrappers import TimeAwareObservation +from humenv import make_humenv +from humenv.bench import ( + RewardEvaluation, + TrackingEvaluation, +) +from humenv.misc.motionlib import canonicalize, load_episode_based_h5 +from packaging.version import Version +from tqdm import tqdm + +import wandb +from metamotivo.buffers.buffers import DictBuffer, TrajectoryBuffer +from metamotivo.fb_cpr import FBcprAgent, FBcprAgentConfig +from metamotivo.wrappers.humenvbench import RewardWrapper, TrackingWrapper + +if Version(humenv.__version__) < Version("0.1.2"): + raise RuntimeError("This script requires humenv>=0.1.2") +if Version(gymnasium.__version__) < Version("1.0"): + raise RuntimeError("This script requires gymnasium>=1.0") + + +def set_seed_everywhere(seed): + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def load_expert_trajectories(motions: str | Path, motions_root: str | Path, device: str) -> TrajectoryBuffer: + with open(motions, "r") as txtf: + h5files = [el.strip().replace(" ", "") for el in txtf.readlines()] + episodes = [] + for h5 in tqdm(h5files, leave=False): + h5 = canonicalize(h5, base_path=motions_root) + _ep = load_episode_based_h5(h5, keys=None) + for el in _ep: + el["observation"] = el["observation"].astype(np.float32) + del el["file_name"] + episodes.extend(_ep) + buffer = TrajectoryBuffer( + capacity=len(episodes), + seq_length=agent_config.model.seq_length, + device=device, + ) + buffer.extend(episodes) + return buffer + + +@dataclasses.dataclass +class TrainConfig: + seed: int = 0 + motions: str = "" + motions_root: str = "" + buffer_size: int = 5_000_000 + online_parallel_envs: int = 50 + log_every_updates: int = 100_000 + work_dir: str | None = None + num_env_steps: int = 30_000_000 + update_agent_every: int | None = None + num_seed_steps: int | None = None + num_agent_updates: int | None = None + checkpoint_every_steps: int = 5_000_000 + prioritization: bool = False + prioritization_min_val: float = 0.5 + prioritization_max_val: float = 5 + prioritization_scale: float = 2 + + # WANDB + use_wandb: bool = False + wandb_ename: str | None = None + wandb_gname: str | None = None + wandb_pname: str | None = "fbcpr_humenv" + + # misc + compile: bool = False + cudagraphs: bool = False + device: str = "cuda" + buffer_device: str = "cpu" + + # eval + evaluate: bool = False + eval_every_steps: int = 1_000_000 + reward_eval_num_envs: int = 5 + reward_eval_num_eval_episodes: int = 10 + reward_eval_num_inference_samples: int = 50_000 + reward_eval_tasks: List[str] | None = None + + tracking_eval_num_envs: int = 60 + tracking_eval_motions: str | None = None + tracking_eval_motions_root: str | None = None + + def __post_init__(self): + if self.reward_eval_tasks is None: + # this is just a subset of the tasks available in humenv + self.reward_eval_tasks = [ + "move-ego-0-0", + "jump-2", + "move-ego-0-2", + "move-ego-90-2", + "move-ego-180-2", + "rotate-x-5-0.8", + "rotate-y-5-0.8", + "rotate-z-5-0.8" + ] + if self.update_agent_every is None: + self.update_agent_every = 10 * self.online_parallel_envs + if self.num_seed_steps is None: + self.num_seed_steps = 1000 * self.online_parallel_envs + if self.num_agent_updates is None: + self.num_agent_updates = self.online_parallel_envs + if self.prioritization: + # NOTE: when using prioritization train and eval motions must match + self.tracking_eval_motions = self.motions + self.tracking_eval_motions_root = self.motions_root + self.evaluate = True + + +class Workspace: + def __init__(self, cfg: TrainConfig, agent_cfg: FBcprAgentConfig) -> None: + self.cfg = cfg + self.agent_cfg = agent_cfg + if self.cfg.work_dir is None: + import string + + tmp_name = "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + self.work_dir = Path.cwd() / "tmp_fbcpr" / tmp_name + self.cfg.work_dir = str(self.work_dir) + else: + self.work_dir = self.cfg.work_dir + print(f"Workdir: {self.work_dir}") + self.work_dir = Path(self.work_dir) + self.work_dir.mkdir(exist_ok=True, parents=True) + + set_seed_everywhere(self.cfg.seed) + self.agent = FBcprAgent(**dataclasses.asdict(agent_cfg)) + + if self.cfg.use_wandb: + exp_name = "fbcpr" + wandb_name = exp_name + # fmt: off + wandb_config = dataclasses.asdict(self.cfg) + wandb.init(entity=self.cfg.wandb_ename, project=self.cfg.wandb_pname, + group=self.cfg.wandb_gname, name=wandb_name, # mode="disabled", + config=wandb_config) # type: ignore + # fmt: on + + with (self.work_dir / "config.json").open("w") as f: + json.dump(dataclasses.asdict(self.cfg), f, indent=4) + + self.manager = None + + def train(self): + self.start_time = time.time() + self.train_online() + + def train_online(self) -> None: + print("Loading expert trajectories") + expert_buffer = load_expert_trajectories(self.cfg.motions, self.cfg.motions_root, device=self.cfg.buffer_device) + + print("Creating the training environment") + train_env, mp_info = make_humenv( + num_envs=self.cfg.online_parallel_envs, + # vectorization_mode="sync", + wrappers=[ + gymnasium.wrappers.FlattenObservation, + lambda env: TimeAwareObservation(env, flatten=False), + ], + render_width=320, + render_height=320, + motions=self.cfg.motions, + motion_base_path=self.cfg.motions_root, + fall_prob=0.2, + state_init="MoCapAndFall", + ) + + print("Allocating buffers") + replay_buffer = { + "train": DictBuffer(capacity=self.cfg.buffer_size, device=self.cfg.buffer_device), + "expert_slicer": expert_buffer, + } + + print("Starting training") + progb = tqdm(total=self.cfg.num_env_steps) + td, info = train_env.reset() + done = np.zeros(self.cfg.online_parallel_envs, dtype=np.bool) + total_metrics, context = None, None + start_time = time.time() + fps_start_time = time.time() + for t in range(0, self.cfg.num_env_steps, self.cfg.online_parallel_envs): + if self.cfg.evaluate and t % self.cfg.eval_every_steps == 0: + eval_metrics = self.eval(t, replay_buffer=replay_buffer) + if self.cfg.prioritization: + # priorities + index_in_buffer = {} + for i, ep in enumerate(replay_buffer["expert_slicer"].storage): + index_in_buffer[ep["motion_id"][0].item()] = i + motions_id, priorities, idxs = [], [], [] + for _, metr in eval_metrics["tracking"].items(): + motions_id.append(metr["motion_id"]) + priorities.append(metr["emd"]) + idxs.append(index_in_buffer[metr["motion_id"]]) + priorities = ( + torch.clamp( + torch.tensor(priorities, dtype=torch.float32, device=self.agent.device), + min=self.cfg.prioritization_min_val, + max=self.cfg.prioritization_max_val, + ) + * self.cfg.prioritization_scale + ) + bins = torch.floor(priorities) + for i in range(int(bins.min().item()), int(bins.max().item()) + 1): + mask = bins == i + n = mask.sum().item() + if n > 0: + priorities[mask] = 1 / n + + if mp_info is not None: + mp_info["motion_buffer"].update_priorities(motions_id=motions_id, priorities=priorities.cpu().numpy()) + else: + train_env.unwrapped.motion_buffer.update_priorities(motions_id=motions_id, priorities=priorities.cpu().numpy()) + replay_buffer["expert_slicer"].update_priorities( + priorities=priorities.to(self.cfg.buffer_device), idxs=torch.tensor(np.array(idxs), device=self.cfg.buffer_device) + ) + + with torch.no_grad(): + obs = torch.tensor(td["obs"], dtype=torch.float32, device=self.agent.device) + step_count = torch.tensor(td["time"], device=self.agent.device) + context = self.agent.maybe_update_rollout_context(z=context, step_count=step_count) + if t < self.cfg.num_seed_steps: + action = train_env.action_space.sample().astype(np.float32) + else: + # this works in inference mode + action = self.agent.act(obs=obs, z=context, mean=False).cpu().detach().numpy() + new_td, reward, terminated, truncated, new_info = train_env.step(action) + real_next_obs = new_td["obs"].astype(np.float32).copy() + new_done = np.logical_or(terminated.ravel(), truncated.ravel()) + + if Version(gymnasium.__version__) >= Version("1.0"): + # We add only transitions corresponding to environments that have not reset in the previous step. + # For environments that have reset in the previous step, the new observation corresponds to the state after reset. + indexes = ~done + data = { + "observation": obs[indexes], + "action": action[indexes], + "z": context[indexes], + "step_count": step_count[indexes], + "qpos": info["qpos"][indexes], + "qvel": info["qvel"][indexes], + "next": { + "observation": real_next_obs[indexes], + "terminated": terminated[indexes].reshape(-1, 1), + "truncated": truncated[indexes].reshape(-1, 1), + "reward": reward[indexes].reshape(-1, 1), + "qpos": new_info["qpos"][indexes], + "qvel": new_info["qvel"][indexes], + }, + } + else: + raise NotImplementedError("still some work to do for gymnasium < 1.0") + replay_buffer["train"].extend(data) + + if len(replay_buffer["train"]) > 0 and t > self.cfg.num_seed_steps and t % self.cfg.update_agent_every == 0: + for _ in range(self.cfg.num_agent_updates): + metrics = self.agent.update(replay_buffer, t) + if total_metrics is None: + num_metrics_updates = 1 + total_metrics = {k: metrics[k].clone() for k in metrics.keys()} + else: + num_metrics_updates += 1 + total_metrics = {k: total_metrics[k] + metrics[k] for k in metrics.keys()} + + if t % self.cfg.log_every_updates == 0 and total_metrics is not None: + m_dict = {} + for k in sorted(list(total_metrics.keys())): + tmp = total_metrics[k] / num_metrics_updates + m_dict[k] = np.round(tmp.mean().item(), 6) + m_dict["duration [minutes]"] = (time.time() - start_time) / 60 + m_dict["FPS"] = (1 if t == 0 else self.cfg.log_every_updates) / (time.time() - fps_start_time) + if self.cfg.use_wandb: + wandb.log( + {f"train/{k}": v for k, v in m_dict.items()}, + step=t, + ) + print(m_dict) + total_metrics = None + fps_start_time = time.time() + + if t % self.cfg.checkpoint_every_steps == 0: + self.agent.save(str(self.work_dir / "checkpoint")) + progb.update(self.cfg.online_parallel_envs) + td = new_td + done = new_done + info = new_info + self.agent.save(str(self.work_dir / "checkpoint")) + if mp_info is not None: + mp_info["manager"].shutdown() + + def eval(self, t, replay_buffer): + print(f"Starting evaluation at time {t}") + inference_function: str = "reward_wr_inference" + + self.agent._model.to("cpu") + self.agent._model.train(False) + + # --------------------------------------------------------------- + # Reward evaluation + # --------------------------------------------------------------- + eval_agent = RewardWrapper( + model=self.agent._model, + inference_dataset=replay_buffer["train"], + num_samples_per_inference=self.cfg.reward_eval_num_inference_samples, + inference_function=inference_function, + max_workers=1, + process_executor=False, + ) + reward_eval = RewardEvaluation( + tasks=self.cfg.reward_eval_tasks, + env_kwargs={"state_init": "Fall", "context": "spawn"}, + num_contexts=1, + num_envs=self.cfg.reward_eval_num_envs, + num_episodes=self.cfg.reward_eval_num_eval_episodes, + ) + start_t = time.time() + reward_metrics = {} + if not replay_buffer["train"].empty(): + print(f"Reward started at {time.ctime(start_t)}", flush=True) + reward_metrics = reward_eval.run(agent=eval_agent) + duration = time.time() - start_t + print(f"Reward eval time: {duration}") + if self.cfg.use_wandb: + m_dict = {} + avg_return = [] + for task in reward_metrics.keys(): + m_dict[f"{task}/return"] = np.mean(reward_metrics[task]["reward"]) + m_dict[f"{task}/return#std"] = np.std(reward_metrics[task]["reward"]) + avg_return.append(reward_metrics[task]["reward"]) + m_dict["reward/return"] = np.mean(avg_return) + m_dict["reward/return#std"] = np.std(avg_return) + m_dict["reward/time"] = duration + wandb.log( + {f"eval/reward/{k}": v for k, v in m_dict.items()}, + step=t, + ) + # --------------------------------------------------------------- + # Tracking evaluation + # --------------------------------------------------------------- + eval_agent = TrackingWrapper(model=self.agent._model) + tracking_eval = TrackingEvaluation( + motions=self.cfg.tracking_eval_motions, + motion_base_path=self.cfg.tracking_eval_motions_root, + env_kwargs={ + "state_init": "Default", + }, + num_envs=self.cfg.tracking_eval_num_envs, + ) + start_t = time.time() + print(f"Tracking started at {time.ctime(start_t)}", flush=True) + tracking_metrics = tracking_eval.run(agent=eval_agent) + duration = time.time() - start_t + print(f"Tracking eval time: {duration}") + if self.cfg.use_wandb: + aggregate, m_dict = collections.defaultdict(list), {} + for _, metr in tracking_metrics.items(): + for k, v in metr.items(): + if isinstance(v, numbers.Number): + aggregate[k].append(v) + for k, v in aggregate.items(): + m_dict[k] = np.mean(v) + m_dict[f"{k}#std"] = np.std(v) + m_dict["time"] = duration + + wandb.log( + {f"eval/tracking/{k}": v for k, v in m_dict.items()}, + step=t, + ) + # --------------------------------------------------------------- + # this is important, move back the agent to cuda and + # restart the training + self.agent._model.to("cuda") + self.agent._model.train() + + return {"reward": reward_metrics, "tracking": tracking_metrics} + + +if __name__ == "__main__": + config = tyro.cli(TrainConfig) + + env, _ = make_humenv( + num_envs=1, + vectorization_mode="sync", + wrappers=[gymnasium.wrappers.FlattenObservation], + render_width=320, + render_height=320, + ) + + agent_config = FBcprAgentConfig() + agent_config.model.obs_dim = env.observation_space.shape[0] + agent_config.model.action_dim = env.action_space.shape[0] + agent_config.model.norm_obs = True + agent_config.train.batch_size = 1024 + agent_config.train.use_mix_rollout = 1 + agent_config.train.update_z_every_step = 150 + agent_config.model.actor_std = 0.2 + agent_config.model.seq_length = 8 + # archi + # the config of the model trained in the paper + model, hidden_dim, hidden_layers = "simple", 1024, 2 + # uncomment the line below for the config of model deployed in the demo + # WARNING: you need to use compile=True on a A100 GPU or better, as otherwise training can be very slow + # model, hidden_dim, hidden_layers = "residual", 2048, 12 + agent_config.model.archi.z_dim = 256 + agent_config.model.archi.b.norm = 1 + agent_config.model.archi.norm_z = 1 + agent_config.model.archi.f.hidden_dim = hidden_dim + agent_config.model.archi.b.hidden_dim = 256 + agent_config.model.archi.actor.hidden_dim = hidden_dim + agent_config.model.archi.critic.hidden_dim = hidden_dim + agent_config.model.archi.f.hidden_layers = hidden_layers + agent_config.model.archi.b.hidden_layers = 1 + agent_config.model.archi.actor.hidden_layers = hidden_layers + agent_config.model.archi.critic.hidden_layers = hidden_layers + agent_config.model.archi.f.model = model + agent_config.model.archi.actor.model = model + agent_config.model.archi.critic.model = model + # optim + agent_config.train.lr_f = 1e-4 + agent_config.train.lr_b = 1e-5 + agent_config.train.lr_actor = 1e-4 + agent_config.train.lr_critic = 1e-4 + agent_config.train.ortho_coef = 100 + agent_config.train.train_goal_ratio = 0.2 + agent_config.train.expert_asm_ratio = 0.6 + agent_config.train.relabel_ratio = 0.8 + agent_config.train.reg_coeff = 0.01 + agent_config.train.q_loss_coef = 0.1 # or 0 + # discriminator cfg + agent_config.train.grad_penalty_discriminator = 10 + agent_config.train.weight_decay_discriminator = 0 + agent_config.train.lr_discriminator = 1e-5 + agent_config.model.archi.discriminator.hidden_layers = 3 + agent_config.model.archi.discriminator.hidden_dim = 1024 + agent_config.model.device = config.device + # misc + agent_config.train.discount = 0.98 + agent_config.compile = config.compile + agent_config.cudagraphs = config.cudagraphs + env.close() + + ws = Workspace(config, agent_cfg=agent_config) + ws.train() diff --git a/metamotivo/__init__.py b/metamotivo/__init__.py index 6878049..5ee8e44 100644 --- a/metamotivo/__init__.py +++ b/metamotivo/__init__.py @@ -44,4 +44,4 @@ def config_from_dict(source: Dict, config_class: Any): return target -__version__ = "0.1.1" +__version__ = "0.1.2" diff --git a/metamotivo/buffers/buffers.py b/metamotivo/buffers/buffers.py index 7892bb9..e487c71 100644 --- a/metamotivo/buffers/buffers.py +++ b/metamotivo/buffers/buffers.py @@ -4,15 +4,16 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import torch + import dataclasses -from typing import Dict, Any, Union, List -from collections.abc import Mapping -from collections import defaultdict -import numpy as np -import numbers import functools +import numbers +from collections import defaultdict +from collections.abc import Mapping +from typing import Any, Dict, List, Union +import numpy as np +import torch Device = Union[str, torch.device] @@ -92,14 +93,14 @@ def sample(self, batch_size) -> Dict[str, torch.Tensor]: self.ind = torch.randint(0, len(self), (batch_size,)) return extract_values(self.storage, self.ind) - def get_full_buffer(self): + def get_full_buffer(self) -> Dict: if self._is_full: return self.storage else: return extract_values(self.storage, torch.arange(0, len(self))) -def extract_values(d, idxs): +def extract_values(d: Dict, idxs: List | torch.Tensor | np.ndarray) -> Dict: result = {} for k, v in d.items(): if isinstance(v, Mapping): @@ -190,12 +191,12 @@ def sample(self, batch_size: int = 1) -> Dict[str, torch.Tensor]: return dict_cat(output) - def update_priorities(self, priorities, idxs): + def update_priorities(self, priorities: torch.Tensor, idxs: torch.Tensor) -> None: self.priorities[idxs] = priorities self.priorities = self.priorities / torch.sum(self.priorities) -def initialize_storage(data, storage, capacity, device) -> None: +def initialize_storage(data: Dict, storage: Dict, capacity: int, device: Device) -> None: def recursive_initialize(d, s): for k, v in d.items(): if isinstance(v, Mapping): @@ -234,10 +235,10 @@ def dtype_numpytotorch(np_dtype: Any) -> torch.dtype: raise ValueError(f"Unknown type {np_dtype}") -def dict_cat(d) -> Dict[str, torch.Tensor]: +def dict_cat(d: Mapping) -> Dict[str, torch.Tensor]: res = {} for k, v in d.items(): - if isinstance(v, dict): + if isinstance(v, Mapping): res[k] = dict_cat(v) else: res[k] = torch.cat(v, dim=0) diff --git a/metamotivo/fb/model.py b/metamotivo/fb/model.py index 5787256..878ef8b 100644 --- a/metamotivo/fb/model.py +++ b/metamotivo/fb/model.py @@ -10,6 +10,9 @@ from torch import nn import torch.nn.functional as F import copy +from pathlib import Path +from safetensors.torch import save_model as safetensors_save_model +import json from ..nn_models import build_backward, build_forward, build_actor, eval_mode from .. import config_from_dict, load_model @@ -93,6 +96,13 @@ def to(self, *args, **kwargs): def load(cls, path: str, device: str | None = None): return load_model(path, device, cls=cls) + def save(self, output_folder: str) -> None: + output_folder = Path(output_folder) + output_folder.mkdir(exist_ok=True) + safetensors_save_model(self, output_folder / "model.safetensors") + with (output_folder / "config.json").open("w+") as f: + json.dump(dataclasses.asdict(self.cfg), f, indent=4) + def _normalize(self, obs: torch.Tensor): with torch.no_grad(), eval_mode(self._obs_normalizer): return self._obs_normalizer(obs) diff --git a/metamotivo/fb_cpr/agent.py b/metamotivo/fb_cpr/agent.py index 287d56b..d76f16e 100644 --- a/metamotivo/fb_cpr/agent.py +++ b/metamotivo/fb_cpr/agent.py @@ -4,16 +4,17 @@ # LICENSE file in the root directory of this source tree. import dataclasses +from typing import Dict + import torch import torch.nn.functional as F -from typing import Dict +from torch import autograd -from ..fb.agent import TrainConfig as FBTrainConfig from ..fb.agent import FBAgent -from .model import FBcprModel, config_from_dict -from ..nn_models import eval_mode, _soft_update_params +from ..fb.agent import TrainConfig as FBTrainConfig +from ..nn_models import _soft_update_params, eval_mode from .model import Config as FBcprModelConfig -from torch import autograd +from .model import FBcprModel, config_from_dict @dataclasses.dataclass @@ -84,10 +85,10 @@ def setup_compile(self): self.update_critic = torch.compile(self.update_critic, mode=mode) self.update_discriminator = torch.compile(self.update_discriminator, mode=mode) self.encode_expert = torch.compile(self.encode_expert, mode=mode, fullgraph=True) - # self.sample_mixed_z = torch.compile(self.sample_mixed_z, mode=mode, fullgraph=True) if self.cfg.cudagraphs: from tensordict.nn import CudaGraphModule + self.update_critic = CudaGraphModule(self.update_critic, warmup=5) self.update_discriminator = CudaGraphModule(self.update_discriminator, warmup=5) self.encode_expert = CudaGraphModule(self.encode_expert, warmup=5) @@ -97,7 +98,11 @@ def sample_mixed_z(self, train_goal: torch.Tensor, expert_encodings: torch.Tenso z = self._model.sample_z(self.cfg.train.batch_size, device=self.device) p_goal = self.cfg.train.train_goal_ratio p_expert_asm = self.cfg.train.expert_asm_ratio - prob = torch.tensor([p_goal, p_expert_asm, 1 - p_goal - p_expert_asm], dtype=torch.float32, device=self.device) + prob = torch.tensor( + [p_goal, p_expert_asm, 1 - p_goal - p_expert_asm], + dtype=torch.float32, + device=self.device, + ) mix_idxs = torch.multinomial(prob, num_samples=self.cfg.train.batch_size, replacement=True).reshape(-1, 1) # zs obtained by encoding train goals @@ -117,11 +122,13 @@ def encode_expert(self, next_obs: torch.Tensor): # encode expert trajectories through B B_expert = self._model._backward_map(next_obs).detach() # batch x d B_expert = B_expert.view( - self.cfg.train.batch_size // self.cfg.model.seq_length, self.cfg.model.seq_length, B_expert.shape[-1] + self.cfg.train.batch_size // self.cfg.model.seq_length, + self.cfg.model.seq_length, + B_expert.shape[-1], ) # N x L x d z_expert = B_expert.mean(dim=1) # N x d z_expert = self._model.project_z(z_expert) - z_expert = z_expert.repeat_interleave(self.cfg.model.seq_length, dim=0) # batch x d + z_expert = torch.repeat_interleave(z_expert, self.cfg.model.seq_length, dim=0) # batch x d return z_expert def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]: @@ -143,17 +150,27 @@ def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]: self._model._obs_normalizer(train_next_obs) with torch.no_grad(), eval_mode(self._model._obs_normalizer): - train_obs, train_next_obs = self._model._obs_normalizer(train_obs), self._model._obs_normalizer(train_next_obs) - expert_obs, expert_next_obs = self._model._obs_normalizer(expert_obs), self._model._obs_normalizer(expert_next_obs) + train_obs, train_next_obs = ( + self._model._obs_normalizer(train_obs), + self._model._obs_normalizer(train_next_obs), + ) + expert_obs, expert_next_obs = ( + self._model._obs_normalizer(expert_obs), + self._model._obs_normalizer(expert_next_obs), + ) + torch.compiler.cudagraph_mark_step_begin() expert_z = self.encode_expert(next_obs=expert_next_obs) train_z = train_batch["z"].to(self.device) # train the discriminator grad_penalty = self.cfg.train.grad_penalty_discriminator if self.cfg.train.grad_penalty_discriminator > 0 else None - # TODO does it make sense to move cudagraph_mark_step_begin here? metrics = self.update_discriminator( - expert_obs=expert_obs, expert_z=expert_z, train_obs=train_obs, train_z=train_z, grad_penalty=grad_penalty + expert_obs=expert_obs, + expert_z=expert_z, + train_obs=train_obs, + train_z=train_z, + grad_penalty=grad_penalty, ) z = self.sample_mixed_z(train_goal=train_next_obs, expert_encodings=expert_z).clone() @@ -166,7 +183,6 @@ def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]: q_loss_coef = self.cfg.train.q_loss_coef if self.cfg.train.q_loss_coef > 0 else None clip_grad_norm = self.cfg.train.clip_grad_norm if self.cfg.train.clip_grad_norm > 0 else None - torch.compiler.cudagraph_mark_step_begin() metrics.update( self.update_fb( obs=train_obs, @@ -179,19 +195,50 @@ def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]: clip_grad_norm=clip_grad_norm, ) ) - metrics.update(self.update_critic(obs=train_obs, action=train_action, discount=discount, next_obs=train_next_obs, z=train_z)) - metrics.update(self.update_actor(obs=train_obs, action=train_action, z=train_z, clip_grad_norm=clip_grad_norm)) + metrics.update( + self.update_critic( + obs=train_obs, + action=train_action, + discount=discount, + next_obs=train_next_obs, + z=train_z, + ) + ) + metrics.update( + self.update_actor( + obs=train_obs, + action=train_action, + z=train_z, + clip_grad_norm=clip_grad_norm, + ) + ) with torch.no_grad(): - _soft_update_params(self._forward_map_paramlist, self._target_forward_map_paramlist, self.cfg.train.fb_target_tau) - _soft_update_params(self._backward_map_paramlist, self._target_backward_map_paramlist, self.cfg.train.fb_target_tau) - _soft_update_params(self._critic_map_paramlist, self._target_critic_map_paramlist, self.cfg.train.critic_target_tau) + _soft_update_params( + self._forward_map_paramlist, + self._target_forward_map_paramlist, + self.cfg.train.fb_target_tau, + ) + _soft_update_params( + self._backward_map_paramlist, + self._target_backward_map_paramlist, + self.cfg.train.fb_target_tau, + ) + _soft_update_params( + self._critic_map_paramlist, + self._target_critic_map_paramlist, + self.cfg.train.critic_target_tau, + ) return metrics @torch.compiler.disable def gradient_penalty_wgan( - self, real_obs: torch.Tensor, real_z: torch.Tensor, fake_obs: torch.Tensor, fake_z: torch.Tensor + self, + real_obs: torch.Tensor, + real_z: torch.Tensor, + fake_obs: torch.Tensor, + fake_z: torch.Tensor, ) -> torch.Tensor: batch_size = real_obs.shape[0] alpha = torch.rand(batch_size, 1, device=real_obs.device) @@ -217,7 +264,12 @@ def gradient_penalty_wgan( return gradient_penalty def update_discriminator( - self, expert_obs: torch.Tensor, expert_z: torch.Tensor, train_obs: torch.Tensor, train_z: torch.Tensor, grad_penalty: float | None + self, + expert_obs: torch.Tensor, + expert_z: torch.Tensor, + train_obs: torch.Tensor, + train_z: torch.Tensor, + grad_penalty: float | None, ) -> Dict[str, torch.Tensor]: expert_logits = self._model._discriminator.compute_logits(obs=expert_obs, z=expert_z) unlabeled_logits = self._model._discriminator.compute_logits(obs=train_obs, z=train_z) @@ -284,7 +336,11 @@ def update_critic( return output_metrics def update_actor( - self, obs: torch.Tensor, action: torch.Tensor, z: torch.Tensor, clip_grad_norm: float | None + self, + obs: torch.Tensor, + action: torch.Tensor, + z: torch.Tensor, + clip_grad_norm: float | None, ) -> Dict[str, torch.Tensor]: dist = self._model._actor(obs, z, self._model.cfg.actor_std) action = dist.sample(clip=self.cfg.train.stddev_clip) diff --git a/metamotivo/wrappers/humenvbench.py b/metamotivo/wrappers/humenvbench.py index d0831f4..959d514 100644 --- a/metamotivo/wrappers/humenvbench.py +++ b/metamotivo/wrappers/humenvbench.py @@ -3,6 +3,7 @@ # This source code is licensed under the CC BY-NC 4.0 license found in the # LICENSE file in the root directory of this source tree. +import copy import torch from typing import Any import numpy as np @@ -58,6 +59,23 @@ def __getattr__(self, name): # Delegate to the wrapped instance return getattr(self.model, name) + def __deepcopy__(self, memo): + return type(self)(model=copy.deepcopy(self.model, memo), numpy_output=self.numpy_output, _dtype=copy.deepcopy(self._dtype)) + + def __getstate__(self): + # Return a dictionary containing the state of the object + return { + "model": self.model, + "numpy_output": self.numpy_output, + "_dtype": self._dtype, + } + + def __setstate__(self, state): + # Restore the state of the object from the given dictionary + self.model = state["model"] + self.numpy_output = state["numpy_output"] + self._dtype = state["_dtype"] + @dataclasses.dataclass(kw_only=True) class RewardWrapper(BaseHumEnvBenchWrapper): @@ -103,6 +121,46 @@ def reward_inference(self, task: str) -> torch.Tensor: ctxs = inference_fn(**td).reshape(1, -1) return ctxs + def __deepcopy__(self, memo): + # Create a new instance of the same type as self + return type(self)( + model=copy.deepcopy(self.model, memo), + numpy_output=self.numpy_output, + _dtype=copy.deepcopy(self._dtype), + inference_dataset=copy.deepcopy(self.inference_dataset), + num_samples_per_inference=self.num_samples_per_inference, + inference_function=self.inference_function, + max_workers=self.max_workers, + process_executor=self.process_executor, + process_context=self.process_context, + ) + + def __getstate__(self): + # Return a dictionary containing the state of the object + return { + "model": self.model, + "numpy_output": self.numpy_output, + "_dtype": self._dtype, + "inference_dataset": self.inference_dataset, + "num_samples_per_inference": self.num_samples_per_inference, + "inference_function": self.inference_function, + "max_workers": self.max_workers, + "process_executor": self.process_executor, + "process_context": self.process_context, + } + + def __setstate__(self, state): + # Restore the state of the object from the given dictionary + self.model = state["model"] + self.numpy_output = state["numpy_output"] + self._dtype = state["_dtype"] + self.inference_dataset = state["inference_dataset"] + self.num_samples_per_inference = state["num_samples_per_inference"] + self.inference_function = state["inference_function"] + self.max_workers = state["max_workers"] + self.process_executor = state["process_executor"] + self.process_context = state["process_context"] + @dataclasses.dataclass(kw_only=True) class GoalWrapper(BaseHumEnvBenchWrapper): diff --git a/tutorial_train.ipynb b/tutorial_train.ipynb deleted file mode 100644 index 0bd47b8..0000000 --- a/tutorial_train.ipynb +++ /dev/null @@ -1,338 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Meta Motivo: online training tutorial\n", - "This notebook is designed for showcasing how to use the library for training an FB-CPR agent. It is not designed to exactly reproduce the results in the paper." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import annotations\n", - "import torch\n", - "\n", - "torch.set_float32_matmul_precision(\"high\")\n", - "\n", - "import gymnasium\n", - "import numpy as np\n", - "import dataclasses\n", - "from humenv import make_humenv\n", - "from humenv.bench.gym_utils.rollouts import rollout\n", - "import mediapy as media\n", - "from metamotivo.buffers.buffers import DictBuffer, TrajectoryBuffer\n", - "from metamotivo.fb_cpr import FBcprAgent, FBcprAgentConfig\n", - "from tqdm.notebook import trange, tqdm\n", - "import time\n", - "from gymnasium import ObservationWrapper\n", - "\n", - "from packaging.version import Version\n", - "\n", - "if Version(gymnasium.__version__) >= Version(\"1.0\"):\n", - " raise RuntimeError(\"This tutorial does not support yet gymnasium >= 1.0\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We need to provide the time step inside an episode to the agent since it is used to decide when to switch policy (i.e., embedding `z`) in a rollout of the online training. Gymnasium >=1.0 provides this wrapper but here we report a simpler version for compatibility with previous versions." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class TimeAwareObservation(ObservationWrapper):\n", - " \"\"\"\n", - " The MIT License\n", - "\n", - " Copyright (c) 2016 OpenAI\n", - " Copyright (c) 2022 Farama Foundation\n", - " \"\"\"\n", - " def __init__(self, env):\n", - " super().__init__(env)\n", - " self.max_timesteps = env.spec.max_episode_steps\n", - " self.timesteps: int = 0\n", - " self._time_preprocess_func = lambda time: np.array([time], dtype=np.int32)\n", - " time_space = gymnasium.spaces.Box(0, self.max_timesteps, dtype=np.int32)\n", - " assert not isinstance(\n", - " env.observation_space, (gymnasium.spaces.Dict, gymnasium.spaces.Tuple)\n", - " )\n", - "\n", - " observation_space = gymnasium.spaces.Dict(\n", - " obs=env.observation_space, time=time_space\n", - " )\n", - " self._append_data_func = lambda obs, time: {\"obs\": obs, \"time\": time}\n", - " self.observation_space = observation_space\n", - " self._obs_postprocess_func = lambda obs: obs\n", - "\n", - " def observation(self, observation):\n", - " return self._obs_postprocess_func(\n", - " self._append_data_func(\n", - " observation, self._time_preprocess_func(self.timesteps)\n", - " )\n", - " )\n", - "\n", - " def step(self, action):\n", - " self.timesteps += 1\n", - " return super().step(action)\n", - "\n", - " def reset(self, *, seed=None, options=None):\n", - " self.timesteps = 0\n", - " return super().reset(seed=seed, options=options)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Agent and Train parameters\n", - "\n", - "We start by defining the parameters of the FB-CPR agent." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "env, _ = make_humenv(\n", - " num_envs=1,\n", - " vectorization_mode=\"sync\",\n", - " wrappers=[gymnasium.wrappers.FlattenObservation],\n", - " render_width=320,\n", - " render_height=320,\n", - ")\n", - "\n", - "agent_config = FBcprAgentConfig()\n", - "agent_config.model.obs_dim = env.observation_space.shape[0]\n", - "agent_config.model.action_dim = env.action_space.shape[0]\n", - "agent_config.model.device = \"cpu\"\n", - "agent_config.model.norm_obs = True\n", - "agent_config.model.seq_length = 1\n", - "# misc\n", - "agent_config.train.discount = 0.98\n", - "agent_config.compile = False\n", - "agent_config.cudagraphs = False\n", - "agent = FBcprAgent(**dataclasses.asdict(agent_config))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We also define a few parameters for online training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "buffer_size = 1_000_000\n", - "online_parallel_envs = 5\n", - "log_every_updates = 100\n", - "online_num_env_steps = 2000\n", - "num_seed_steps = 1000" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# \"Expert\" trajectories\n", - "FB-CPR leverages expert observation-only trajecteries in the training process. For training Meta Motivo you can use the motion capture dataset as described in the HumEnv repository. Here for simplicity we create \"expert\" trajectories running a random agent." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class RandomAgent:\n", - " def __init__(self, env):\n", - " self.env = env\n", - "\n", - " def act(self, *args, **kwargs):\n", - " return self.env.action_space.sample()\n", - "\n", - "\n", - "random_agent = RandomAgent(env)\n", - "_, episodes = rollout(env=env, agent=random_agent, num_episodes=4)\n", - "for ep in episodes:\n", - " ep[\"observation\"] = ep[\"observation\"].astype(np.float32)\n", - " del ep[\"action\"]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can visualize an episode by reloading `qpos` (and optionally `qvel`) information." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ep = episodes[0]\n", - "frames = []\n", - "for i in range(len(ep[\"info\"][\"qpos\"])):\n", - " env.unwrapped.set_physics(ep[\"info\"][\"qpos\"][i])\n", - " frames.append(env.render())\n", - "media.show_video(frames, fps=30)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With this tutorial we provide a simple buffer for storing trajectories (see `examples/trajecotory_buffer.py`)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "expert_buffer = TrajectoryBuffer(\n", - " capacity=len(episodes),\n", - " seq_length=agent_config.model.seq_length,\n", - " device=agent.device,\n", - ")\n", - "expert_buffer.extend(episodes)\n", - "print(expert_buffer)\n", - "env.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training loop\n", - "This section describes the training loop that should be self explanatory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_env, _ = make_humenv(\n", - " num_envs=online_parallel_envs,\n", - " vectorization_mode=\"sync\",\n", - " wrappers=[\n", - " gymnasium.wrappers.FlattenObservation,\n", - " lambda env: TimeAwareObservation(env),\n", - " ],\n", - " render_width=320,\n", - " render_height=320,\n", - ")\n", - "\n", - "replay_buffer = {\n", - " \"train\": DictBuffer(capacity=buffer_size, device=agent.device),\n", - " \"expert_slicer\": expert_buffer,\n", - "}\n", - "obs, _ = train_env.reset()\n", - "print(obs.keys())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "progb = tqdm(total=online_num_env_steps)\n", - "td, info = train_env.reset()\n", - "total_metrics, context = None, None\n", - "start_time = time.time()\n", - "for t in range(0, online_num_env_steps, online_parallel_envs):\n", - " with torch.no_grad():\n", - " obs = torch.tensor(td[\"obs\"], dtype=torch.float32, device=agent.device)\n", - " step_count = torch.tensor(td[\"time\"], device=agent.device)\n", - " context = agent.maybe_update_rollout_context(z=context, step_count=step_count)\n", - " if t < num_seed_steps:\n", - " action = train_env.action_space.sample().astype(np.float32)\n", - " else:\n", - " # this works in inference mode\n", - " action = agent.act(obs=obs, z=context, mean=False).cpu().detach().numpy()\n", - " new_td, reward, terminated, truncated, new_info = train_env.step(action)\n", - " real_next_obs = new_td[\"obs\"].astype(np.float32).copy()\n", - " done = np.logical_or(terminated.ravel(), truncated.ravel())\n", - " for idx, trunc in enumerate(done):\n", - " if trunc:\n", - " print(new_info[\"final_observation\"])\n", - " real_next_obs[idx] = new_info[\"final_observation\"][idx][\"obs\"].astype(\n", - " np.float32\n", - " )\n", - " data = {\n", - " \"observation\": obs,\n", - " \"action\": action,\n", - " \"z\": context,\n", - " \"step_count\": step_count,\n", - " \"next\": {\n", - " \"observation\": real_next_obs,\n", - " \"terminated\": terminated.reshape(-1, 1),\n", - " \"truncated\": truncated.reshape(-1, 1),\n", - " \"reward\": reward.reshape(-1, 1),\n", - " },\n", - " }\n", - " replay_buffer[\"train\"].extend(data)\n", - "\n", - " metrics = agent.update(replay_buffer, t)\n", - "\n", - " if total_metrics is None:\n", - " total_metrics = {k: metrics[k] * 1 for k in metrics.keys()}\n", - " else:\n", - " total_metrics = {k: total_metrics[k] + metrics[k] for k in metrics.keys()}\n", - " if t % log_every_updates == 0:\n", - " m_dict = {}\n", - " for k in sorted(list(total_metrics.keys())):\n", - " tmp = total_metrics[k] / (1 if t == 0 else log_every_updates)\n", - " m_dict[k] = np.round(tmp.mean().item(), 6)\n", - " m_dict[\"duration\"] = time.time() - start_time\n", - " print(f\"Steps: {t}\\n{m_dict}\")\n", - " total_metrics = None\n", - " progb.update(online_parallel_envs)\n", - " td = new_td" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}