-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy pathplay_policy.py
More file actions
89 lines (67 loc) · 2.51 KB
/
play_policy.py
File metadata and controls
89 lines (67 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import json
# Runtime configuration via environment variables. These assignments run
# before the numpy/torch/project imports further down; presumably that order
# is required so the libraries read the values when they load — keep it.

# Append the Triton GEMM flag to whatever XLA_FLAGS the caller already set.
xla_flags = os.environ.get("XLA_FLAGS", "")
xla_flags += " --xla_gpu_triton_gemm_any=True"
os.environ["XLA_FLAGS"] = xla_flags
# Per the JAX GPU memory docs, "false" turns off up-front GPU memory preallocation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Ask MuJoCo for the EGL rendering backend (see MuJoCo's MUJOCO_GL documentation).
os.environ["MUJOCO_GL"] = "egl"
# Raise the standard-library root logger to INFO.
import logging as python_logging
LOGGER = python_logging.getLogger()
LOGGER.setLevel(python_logging.INFO)
# From here on the bare name `logging` is absl's logging module, not the stdlib one.
from absl import logging
logging.set_verbosity(logging.INFO)
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import tyro
from tqdm import tqdm
import torch
from src.envs.g1.g1_tracking_env import default_config
from src.envs.g1.play_g1_tracking_env import PlayG1TrackingEnv
@dataclass
class Args:
    """Command-line options for this script, parsed by ``tyro.cli(Args)``.

    NOTE(review): tyro can surface the class docstring and the per-field
    comments below as ``--help`` text — confirm before rewording them.
    """

    # Experiment folder name; the script reads experiments/<exp_name>/checkpoints.
    exp_name: str
    # Forwarded unchanged to PlayG1TrackingEnv(play_ref_motion=...).
    play_ref_motion: bool = False
    use_viewer: bool = False  # passive viewer (with display)
    use_renderer: bool = False  # renderer with video (headless mode)
@dataclass
class State:
    """Plain container holding an ``info`` dict and an ``obs`` dict.

    NOTE(review): nothing else visible in this file constructs or references
    this class; presumably it mirrors the state object returned by the
    environment's ``reset``/``step`` — verify against the env implementation.
    """

    info: dict
    obs: dict
def get_latest_ckpt(path: Path) -> Path | None:
ckpts = [ckpt for ckpt in path.glob("*") if not ckpt.name.endswith(".json")]
ckpts.sort(key=lambda x: int(x.name))
return ckpts[-1] if ckpts else None
def play(args: Args):
    """Roll out the newest saved policy of an experiment in the play environment.

    Steps:
      1. Start from ``default_config().env_config`` and overlay the
         ``env_config`` section of
         ``experiments/<exp_name>/checkpoints/config.json``.
      2. Build a ``PlayG1TrackingEnv`` with the viewer/renderer flags in *args*.
      3. Load ``policy.pt`` (TorchScript, mapped to CPU) from the checkpoint
         chosen by ``get_latest_ckpt``.
      4. Step the environment with the policy's actions, then close it.

    Args:
        args: Parsed command-line options.

    Raises:
        FileNotFoundError: If ``config.json`` is missing or no checkpoint is
            found in the checkpoints directory.
    """
    ckpt_dir = Path(__file__).parent / "experiments" / args.exp_name / "checkpoints"

    task_cfg = default_config()
    env_cfg = task_cfg.env_config
    with open(ckpt_dir / "config.json", "r") as f:
        config = json.load(f)
    # Discard the saved reference-trajectory settings so the ones from
    # default_config() stay in effect; a config without the key is fine too.
    config["env_config"].pop("reference_traj_config", None)
    env_cfg.update(config["env_config"])
    # Optional manual override, left disabled:
    # env_cfg.reference_traj_config.name = {"lafan1": ["dance1_subject2"]}

    env = PlayG1TrackingEnv(
        terrain_type=env_cfg.terrain_type,
        config=env_cfg,
        play_ref_motion=args.play_ref_motion,
        use_viewer=args.use_viewer,
        use_renderer=args.use_renderer,
        exp_name=args.exp_name,
    )
    try:
        latest_ckpt = get_latest_ckpt(ckpt_dir)
        if latest_ckpt is None:
            raise FileNotFoundError("No checkpoint found.")
        policy_path = latest_ckpt / "policy.pt"
        policy_jit = torch.jit.load(policy_path, map_location="cpu")

        state = env.reset()
        # NOTE(review): the reason for subtracting len(...name) and 1 from the
        # qpos row count is not visible in this file — confirm with the env code.
        len_traj = env.th.traj.data.qpos.shape[0] - len(env_cfg.reference_traj_config.name) - 1
        for _ in tqdm(range(len_traj)):
            # Flatten the "state" observation into a single float32 batch row.
            obs = state.obs["state"].reshape(1, -1).astype(np.float32)
            with torch.no_grad():
                action = policy_jit(torch.from_numpy(obs)).cpu().numpy()
            state = env.step(state, action)
    finally:
        # Always close the environment, including when loading or stepping
        # fails, so the viewer/renderer is not left open.
        env.close()
if __name__ == "__main__":
    # Parse the command line into an Args instance with tyro, then run the rollout.
    play(tyro.cli(Args))