Fix wrappers and buffers (#7)
## Fixes and logic changes

1. Add missing priorities in the trajectory buffer
2. Fix cloning of `z`s in mixed sampling of the fb and fb_cpr agents
3. Add load and save in the fb agent
4. Fixes in reward inference (detach tensors and move to numpy)
5. Improvement in relabeling (do not create an executor pool when `max_workers==1`)
6. Fix `get_full_buffer`

## Dependencies:

1. Specify `safetensors>=0.4.5`; previously the version was unpinned
MateuszGuzek authored Dec 20, 2024
1 parent b947ca6 commit 4e24a53
Showing 5 changed files with 85 additions and 27 deletions.
33 changes: 23 additions & 10 deletions metamotivo/buffers/buffers.py
@@ -90,17 +90,23 @@ def add_new_data(data, storage, expected_dim: int):
@torch.no_grad
def sample(self, batch_size) -> Dict[str, torch.Tensor]:
self.ind = torch.randint(0, len(self), (batch_size,))
return extract_values(self.storage, self.ind)

def get_full_buffer(self):
if self._is_full:
return self.storage
else:
return extract_values(self.storage, torch.arange(0, len(self)))

def extract_values(d):
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = extract_values(v)
else:
result[k] = v[self.ind] + 0 # fast copy
return result

return extract_values(self.storage)
def extract_values(d, idxs):
result = {}
for k, v in d.items():
if isinstance(v, Mapping):
result[k] = extract_values(v, idxs)
else:
result[k] = v[idxs]
return result

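A minimal standalone sketch (not part of the library) of how the new `extract_values(d, idxs)` walks nested storage dicts; the example tensors are made up. Indexing with an index tensor already returns a copy, which is why the old `+ 0` trick is no longer needed.

```python
from collections.abc import Mapping

import torch


def extract_values(d, idxs):
    result = {}
    for k, v in d.items():
        if isinstance(v, Mapping):
            result[k] = extract_values(v, idxs)  # recurse into nested storage
        else:
            result[k] = v[idxs]  # advanced indexing returns a new tensor (a copy)
    return result


storage = {"observation": torch.arange(10).reshape(5, 2), "info": {"qpos": torch.zeros(5, 3)}}
batch = extract_values(storage, torch.tensor([0, 2]))
print(batch["observation"].shape, batch["info"]["qpos"].shape)  # (2, 2) and (2, 3)
```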

@dataclasses.dataclass
@@ -115,6 +121,7 @@ def __post_init__(self) -> None:
self._is_full = False
self.storage = None
self._idx = 0
self.priorities = None

def __len__(self) -> int:
return self.capacity if self._is_full else self._idx
@@ -132,6 +139,7 @@ def extend(self, data: List[dict]) -> None:
self.storage.append(element)
self._idx = 0
self._is_full = False
self.priorities = torch.ones(self.capacity, device=self.device, dtype=torch.float32) / self.capacity

def add(new_data):
storage = {}
@@ -164,7 +172,8 @@ def sample(self, batch_size: int = 1) -> Dict[str, torch.Tensor]:
)
num_slices = batch_size // self.seq_length

self.ep_ind = torch.randint(0, len(self), (num_slices,))
# self.ep_ind = torch.randint(0, len(self), (num_slices,))
self.ep_ind = torch.multinomial(self.priorities, num_slices, replacement=True)
output = defaultdict(list)
offset = 0
if len(self.output_key_tp1) > 0:
@@ -181,6 +190,10 @@ def sample(self, batch_size: int = 1) -> Dict[str, torch.Tensor]:

return dict_cat(output)

def update_priorities(self, priorities, idxs):
self.priorities[idxs] = priorities
self.priorities = self.priorities / torch.sum(self.priorities)

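A hedged sketch of the prioritized episode sampling this hunk adds: priorities start uniform over the buffer capacity, episodes are drawn with `torch.multinomial`, and `update_priorities` renormalizes so the vector stays a proper distribution. The capacity, batch size, and scores below are made-up values.

```python
import torch

capacity, num_slices = 8, 4
priorities = torch.ones(capacity, dtype=torch.float32) / capacity  # uniform at initialization

# sample episode indices proportionally to their priority
ep_ind = torch.multinomial(priorities, num_slices, replacement=True)

# later, overwrite the priorities of the sampled episodes and renormalize
new_scores = torch.rand(num_slices)
priorities[ep_ind] = new_scores
priorities = priorities / torch.sum(priorities)
print(ep_ind, priorities.sum())  # sums to 1.0 (up to float error)
```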

def initialize_storage(data, storage, capacity, device) -> None:
def recursive_initialize(d, s):
42 changes: 40 additions & 2 deletions metamotivo/fb/agent.py
@@ -12,6 +12,9 @@
from .model import Config as FBModelConfig
from ..nn_models import weight_init, _soft_update_params, eval_mode
from ..misc.zbuffer import ZBuffer
from pathlib import Path
import json
import safetensors


@dataclasses.dataclass
@@ -142,7 +145,7 @@ def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]:
obs, next_obs = self._model._obs_normalizer(obs), self._model._obs_normalizer(next_obs)

torch.compiler.cudagraph_mark_step_begin()
z = self.sample_mixed_z(train_goal=next_obs) + 0 # fast copy
z = self.sample_mixed_z(train_goal=next_obs).clone()
self.z_buffer.add(z)

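One reading of the `.clone()` change, shown with toy tensors (unrelated to the real `z` sampling): the buffer keeps an independent copy, so later in-place changes to the live tensor cannot leak into what was stored.

```python
import torch

z = torch.randn(4, 8)
stored = z.clone()   # independent copy that the buffer can own
z.mul_(0.0)          # in-place update to the live tensor
print(torch.equal(stored, z))  # False: the stored copy is unaffected
```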
q_loss_coef = self.cfg.train.q_loss_coef if self.cfg.train.q_loss_coef > 0 else None
@@ -210,7 +213,7 @@ def update_fb(
orth_loss = orth_loss_offdiag + orth_loss_diag
fb_loss += self.cfg.train.ortho_coef * orth_loss

q_loss = torch.tensor([0.0], device=z.device)
q_loss = torch.zeros(1, device=z.device, dtype=z.dtype)
if q_loss_coef is not None:
with torch.no_grad():
next_Qs = (target_Fs * z).sum(dim=-1) # num_parallel x batch
@@ -307,3 +310,38 @@ def maybe_update_rollout_context(self, z: torch.Tensor | None, step_count: torch
else:
z = self._model.sample_z(step_count.shape[0], device=self.cfg.model.device)
return z

@classmethod
def load(cls, path: str, device: str | None = None):
path = Path(path)
with (path / "config.json").open() as f:
loaded_config = json.load(f)
if device is not None:
loaded_config["model"]["device"] = device
agent = cls(**loaded_config)
optimizers = torch.load(str(path / "optimizers.pth"), weights_only=True)
agent.actor_optimizer.load_state_dict(optimizers["actor_optimizer"])
agent.backward_optimizer.load_state_dict(optimizers["backward_optimizer"])
agent.forward_optimizer.load_state_dict(optimizers["forward_optimizer"])

safetensors.torch.load_model(agent._model, path / "model/model.safetensors", device=device)
return agent

def save(self, output_folder: str) -> None:
output_folder = Path(output_folder)
output_folder.mkdir(exist_ok=True)
with (output_folder / "config.json").open("w+") as f:
json.dump(dataclasses.asdict(self.cfg), f, indent=4)
# save optimizer
torch.save(
{
"actor_optimizer": self.actor_optimizer.state_dict(),
"backward_optimizer": self.backward_optimizer.state_dict(),
"forward_optimizer": self.forward_optimizer.state_dict(),
},
output_folder / "optimizers.pth",
)
# save model
model_folder = output_folder / "model"
model_folder.mkdir(exist_ok=True)
self._model.save(output_folder=str(model_folder))
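A self-contained sketch of the checkpoint layout the new `save()`/`load()` produce: a JSON config, a `.pth` with optimizer state dicts, and model weights in safetensors. The tiny linear model and single optimizer are stand-ins, not the real FB agent.

```python
import json
from pathlib import Path

import torch
from safetensors.torch import load_model, save_model

net = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

out = Path("/tmp/fb_checkpoint_demo")
(out / "model").mkdir(parents=True, exist_ok=True)
with (out / "config.json").open("w") as f:
    json.dump({"lr": 1e-3}, f, indent=4)
torch.save({"actor_optimizer": opt.state_dict()}, out / "optimizers.pth")
save_model(net, str(out / "model" / "model.safetensors"))

# restore into a freshly constructed model, optionally overriding the device
restored = torch.nn.Linear(4, 2)
load_model(restored, str(out / "model" / "model.safetensors"), device="cpu")
opt.load_state_dict(torch.load(out / "optimizers.pth", weights_only=True)["actor_optimizer"])
```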
2 changes: 1 addition & 1 deletion metamotivo/fb_cpr/agent.py
@@ -156,7 +156,7 @@ def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]:
expert_obs=expert_obs, expert_z=expert_z, train_obs=train_obs, train_z=train_z, grad_penalty=grad_penalty
)

z = self.sample_mixed_z(train_goal=train_next_obs, expert_encodings=expert_z) + 0 # fast copy
z = self.sample_mixed_z(train_goal=train_next_obs, expert_encodings=expert_z).clone()
self.z_buffer.add(z)

if self.cfg.train.relabel_ratio is not None:
33 changes: 20 additions & 13 deletions metamotivo/wrappers/humenvbench.py
@@ -15,7 +15,7 @@


def get_next(field: str, data: Any):
if ("next", field) in data:
if "next" in data and field in data["next"]:
return data["next"][field]
elif f"next_{field}" in data:
return data[f"next_{field}"]
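A toy check of the `get_next` fix: with a plain nested dict, membership has to be tested key by key; the old `("next", field) in data` form only works for containers that accept tuple keys (e.g. TensorDict-style mappings).

```python
def get_next(field, data):
    if "next" in data and field in data["next"]:
        return data["next"][field]
    elif f"next_{field}" in data:
        return data[f"next_{field}"]
    raise KeyError(f"no next value for {field!r}")


print(get_next("qpos", {"next": {"qpos": 1.0}}))  # 1.0
print(get_next("qpos", {"next_qpos": 2.0}))       # 2.0
```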
Expand Down Expand Up @@ -77,6 +77,10 @@ def reward_inference(self, task: str) -> torch.Tensor:
qpos = get_next("qpos", data)
qvel = get_next("qvel", data)
action = data["action"]
if isinstance(qpos, torch.Tensor):
qpos = qpos.cpu().detach().numpy()
qvel = qvel.cpu().detach().numpy()
action = action.cpu().detach().numpy()
rewards = relabel(
env,
qpos,
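A toy illustration of the tensor-to-NumPy conversion added above: tensors that track gradients (and possibly live on GPU) must be detached and moved to the CPU before `.numpy()` will succeed, so `relabel` receives plain NumPy arrays.

```python
import torch

qpos = torch.randn(3, 7, requires_grad=True)  # stand-in for a batch of qpos tensors
# qpos.numpy() would raise: can't call numpy() on a tensor that requires grad
qpos_np = qpos.detach().cpu().numpy()
print(type(qpos_np), qpos_np.shape)  # <class 'numpy.ndarray'> (3, 7)
```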
@@ -154,19 +158,22 @@ def relabel(
):
chunk_size = int(np.ceil(qpos.shape[0] / max_workers))
args = [(qpos[i : i + chunk_size], qvel[i : i + chunk_size], action[i : i + chunk_size]) for i in range(0, qpos.shape[0], chunk_size)]
if process_executor:
import multiprocessing

with ProcessPoolExecutor(
max_workers=max_workers,
mp_context=multiprocessing.get_context(process_context),
) as exe:
f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)
result = exe.map(f, args)
if max_workers == 1:
result = [_relabel_worker(args[0], model=env.unwrapped.model, reward_fn=reward_fn)]
else:
with ThreadPoolExecutor(max_workers=max_workers) as exe:
f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)
result = exe.map(f, args)
if process_executor:
import multiprocessing

with ProcessPoolExecutor(
max_workers=max_workers,
mp_context=multiprocessing.get_context(process_context),
) as exe:
f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)
result = exe.map(f, args)
else:
with ThreadPoolExecutor(max_workers=max_workers) as exe:
f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)
result = exe.map(f, args)

tmp = [r for r in result]
return np.concatenate(tmp)
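A hedged sketch of the new dispatch logic in `relabel`: with `max_workers == 1` the chunking produces a single chunk, so it is processed inline and no executor is created; otherwise work goes to a process or thread pool exactly as before. `_square` below is a stand-in for `_relabel_worker`.

```python
import functools
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def _square(chunk, offset=0):
    return [x * x + offset for x in chunk]


def run(chunks, max_workers=1, process_executor=False):
    f = functools.partial(_square, offset=0)
    if max_workers == 1:
        return [f(chunks[0])]  # single chunk: skip pool start-up entirely
    pool_cls = ProcessPoolExecutor if process_executor else ThreadPoolExecutor
    with pool_cls(max_workers=max_workers) as exe:
        return list(exe.map(f, chunks))


if __name__ == "__main__":
    print(run([[1, 2, 3]], max_workers=1))       # [[1, 4, 9]]
    print(run([[1, 2], [3, 4]], max_workers=2))  # [[1, 4], [9, 16]]
```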
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,8 +5,8 @@ description = "Inference and Training of FB-CPR"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"safetensors>=0.4.5",
"torch>=2.3",
"safetensors",
]

[project.urls]
