diff --git a/metamotivo/buffers/buffers.py b/metamotivo/buffers/buffers.py
index 78ee6f7..7892bb9 100644
--- a/metamotivo/buffers/buffers.py
+++ b/metamotivo/buffers/buffers.py
@@ -90,17 +90,23 @@ def add_new_data(data, storage, expected_dim: int):
     @torch.no_grad
     def sample(self, batch_size) -> Dict[str, torch.Tensor]:
         self.ind = torch.randint(0, len(self), (batch_size,))
+        return extract_values(self.storage, self.ind)
+
+    def get_full_buffer(self):
+        if self._is_full:
+            return self.storage
+        else:
+            return extract_values(self.storage, torch.arange(0, len(self)))
 
-        def extract_values(d):
-            result = {}
-            for k, v in d.items():
-                if isinstance(v, dict):
-                    result[k] = extract_values(v)
-                else:
-                    result[k] = v[self.ind] + 0  # fast copy
-            return result
 
-        return extract_values(self.storage)
+def extract_values(d, idxs):
+    result = {}
+    for k, v in d.items():
+        if isinstance(v, Mapping):
+            result[k] = extract_values(v, idxs)
+        else:
+            result[k] = v[idxs]
+    return result
 
 
 @dataclasses.dataclass
@@ -115,6 +121,7 @@ def __post_init__(self) -> None:
         self._is_full = False
         self.storage = None
         self._idx = 0
+        self.priorities = None
 
     def __len__(self) -> int:
         return self.capacity if self._is_full else self._idx
@@ -132,6 +139,7 @@ def extend(self, data: List[dict]) -> None:
                 self.storage.append(element)
             self._idx = 0
             self._is_full = False
+            self.priorities = torch.ones(self.capacity, device=self.device, dtype=torch.float32) / self.capacity
 
         def add(new_data):
             storage = {}
@@ -164,7 +172,8 @@ def sample(self, batch_size: int = 1) -> Dict[str, torch.Tensor]:
             )
 
         num_slices = batch_size // self.seq_length
-        self.ep_ind = torch.randint(0, len(self), (num_slices,))
+        # self.ep_ind = torch.randint(0, len(self), (num_slices,))
+        self.ep_ind = torch.multinomial(self.priorities, num_slices, replacement=True)
         output = defaultdict(list)
         offset = 0
         if len(self.output_key_tp1) > 0:
@@ -181,6 +190,10 @@ def sample(self, batch_size: int = 1) -> Dict[str, torch.Tensor]:
 
         return dict_cat(output)
 
+    def update_priorities(self, priorities, idxs):
+        self.priorities[idxs] = priorities
+        self.priorities = self.priorities / torch.sum(self.priorities)
+
 
 def initialize_storage(data, storage, capacity, device) -> None:
     def recursive_initialize(d, s):
diff --git a/metamotivo/fb/agent.py b/metamotivo/fb/agent.py
index 4561328..8c87616 100644
--- a/metamotivo/fb/agent.py
+++ b/metamotivo/fb/agent.py
@@ -12,6 +12,9 @@
 from .model import Config as FBModelConfig
 from ..nn_models import weight_init, _soft_update_params, eval_mode
 from ..misc.zbuffer import ZBuffer
+from pathlib import Path
+import json
+import safetensors
 
 
 @dataclasses.dataclass
@@ -142,7 +145,7 @@ def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]:
         obs, next_obs = self._model._obs_normalizer(obs), self._model._obs_normalizer(next_obs)
 
         torch.compiler.cudagraph_mark_step_begin()
-        z = self.sample_mixed_z(train_goal=next_obs) + 0  # fast copy
+        z = self.sample_mixed_z(train_goal=next_obs).clone()
         self.z_buffer.add(z)
 
         q_loss_coef = self.cfg.train.q_loss_coef if self.cfg.train.q_loss_coef > 0 else None
@@ -210,7 +213,7 @@ def update_fb(
             orth_loss = orth_loss_offdiag + orth_loss_diag
             fb_loss += self.cfg.train.ortho_coef * orth_loss
 
-        q_loss = torch.tensor([0.0], device=z.device)
+        q_loss = torch.zeros(1, device=z.device, dtype=z.dtype)
         if q_loss_coef is not None:
             with torch.no_grad():
                 next_Qs = (target_Fs * z).sum(dim=-1)  # num_parallel x batch
@@ -307,3 +310,38 @@ def maybe_update_rollout_context(self, z: torch.Tensor | None, step_count: torch
         else:
             z = self._model.sample_z(step_count.shape[0], device=self.cfg.model.device)
         return z
+
+    @classmethod
+    def load(cls, path: str, device: str | None = None):
+        path = Path(path)
+        with (path / "config.json").open() as f:
+            loaded_config = json.load(f)
+        if device is not None:
+            loaded_config["model"]["device"] = device
+        agent = cls(**loaded_config)
+        optimizers = torch.load(str(path / "optimizers.pth"), weights_only=True)
+        agent.actor_optimizer.load_state_dict(optimizers["actor_optimizer"])
+        agent.backward_optimizer.load_state_dict(optimizers["backward_optimizer"])
+        agent.forward_optimizer.load_state_dict(optimizers["forward_optimizer"])
+
+        safetensors.torch.load_model(agent._model, path / "model/model.safetensors", device=device)
+        return agent
+
+    def save(self, output_folder: str) -> None:
+        output_folder = Path(output_folder)
+        output_folder.mkdir(exist_ok=True)
+        with (output_folder / "config.json").open("w+") as f:
+            json.dump(dataclasses.asdict(self.cfg), f, indent=4)
+        # save optimizer
+        torch.save(
+            {
+                "actor_optimizer": self.actor_optimizer.state_dict(),
+                "backward_optimizer": self.backward_optimizer.state_dict(),
+                "forward_optimizer": self.forward_optimizer.state_dict(),
+            },
+            output_folder / "optimizers.pth",
+        )
+        # save model
+        model_folder = output_folder / "model"
+        model_folder.mkdir(exist_ok=True)
+        self._model.save(output_folder=str(model_folder))
diff --git a/metamotivo/fb_cpr/agent.py b/metamotivo/fb_cpr/agent.py
index 87f4b82..287d56b 100644
--- a/metamotivo/fb_cpr/agent.py
+++ b/metamotivo/fb_cpr/agent.py
@@ -156,7 +156,7 @@ def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]:
             expert_obs=expert_obs, expert_z=expert_z, train_obs=train_obs, train_z=train_z, grad_penalty=grad_penalty
         )
 
-        z = self.sample_mixed_z(train_goal=train_next_obs, expert_encodings=expert_z) + 0  # fast copy
+        z = self.sample_mixed_z(train_goal=train_next_obs, expert_encodings=expert_z).clone()
         self.z_buffer.add(z)
 
         if self.cfg.train.relabel_ratio is not None:
diff --git a/metamotivo/wrappers/humenvbench.py b/metamotivo/wrappers/humenvbench.py
index 77230f8..d0831f4 100644
--- a/metamotivo/wrappers/humenvbench.py
+++ b/metamotivo/wrappers/humenvbench.py
@@ -15,7 +15,7 @@
 
 
 def get_next(field: str, data: Any):
-    if ("next", field) in data:
+    if "next" in data and field in data["next"]:
         return data["next"][field]
     elif f"next_{field}" in data:
         return data[f"next_{field}"]
@@ -77,6 +77,10 @@ def reward_inference(self, task: str) -> torch.Tensor:
         qpos = get_next("qpos", data)
         qvel = get_next("qvel", data)
         action = data["action"]
+        if isinstance(qpos, torch.Tensor):
+            qpos = qpos.cpu().detach().numpy()
+            qvel = qvel.cpu().detach().numpy()
+            action = action.cpu().detach().numpy()
         rewards = relabel(
             env,
             qpos,
@@ -154,19 +158,22 @@ def relabel(
 ):
     chunk_size = int(np.ceil(qpos.shape[0] / max_workers))
     args = [(qpos[i : i + chunk_size], qvel[i : i + chunk_size], action[i : i + chunk_size]) for i in range(0, qpos.shape[0], chunk_size)]
-    if process_executor:
-        import multiprocessing
-
-        with ProcessPoolExecutor(
-            max_workers=max_workers,
-            mp_context=multiprocessing.get_context(process_context),
-        ) as exe:
-            f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)
-            result = exe.map(f, args)
+    if max_workers == 1:
+        result = [_relabel_worker(args[0], model=env.unwrapped.model, reward_fn=reward_fn)]
     else:
-        with ThreadPoolExecutor(max_workers=max_workers) as exe:
-            f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)
-            result = exe.map(f, args)
+        if process_executor:
+            import multiprocessing
+
+            with ProcessPoolExecutor(
+                max_workers=max_workers,
+                mp_context=multiprocessing.get_context(process_context),
+            ) as exe:
+                f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)
+                result = exe.map(f, args)
+        else:
+            with ThreadPoolExecutor(max_workers=max_workers) as exe:
+                f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)
+                result = exe.map(f, args)
     tmp = [r for r in result]
     return np.concatenate(tmp)
 
diff --git a/pyproject.toml b/pyproject.toml
index c7ea176..c702cb8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,8 +5,8 @@ description = "Inference and Training of FB-CPR"
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
+    "safetensors>=0.4.5",
     "torch>=2.3",
-    "safetensors",
 ]
 
 [project.urls]
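
For illustration only, and not part of the patch: a minimal, self-contained sketch of the prioritized-sampling pattern the buffer changes above introduce (uniform priorities at initialization, episode indices drawn with torch.multinomial, priorities renormalized after an update). The class and variable names below are invented for this example and do not exist in metamotivo.

import torch

class PrioritizedEpisodeSampler:
    """Toy stand-in for the buffer's new priority bookkeeping (illustrative only)."""

    def __init__(self, capacity: int, device: str = "cpu") -> None:
        # Uniform priorities at initialization, mirroring the patched extend().
        self.priorities = torch.ones(capacity, device=device) / capacity

    def sample_episodes(self, num_slices: int) -> torch.Tensor:
        # Draw episode indices proportionally to their priority (with replacement).
        return torch.multinomial(self.priorities, num_slices, replacement=True)

    def update_priorities(self, priorities: torch.Tensor, idxs: torch.Tensor) -> None:
        # Overwrite the sampled entries, then renormalize so the vector stays a distribution.
        self.priorities[idxs] = priorities
        self.priorities = self.priorities / torch.sum(self.priorities)

sampler = PrioritizedEpisodeSampler(capacity=8)
idxs = sampler.sample_episodes(num_slices=4)
sampler.update_priorities(torch.rand(idxs.shape[0]), idxs)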