README.md (20 additions, 0 deletions)
@@ -100,13 +100,33 @@ pip install git+https://github.com/juelg/agents.git

For more details, see the [OpenVLA github page](https://github.com/openvla/openvla).

### OpenPi / Pi0
To use OpenPi, create a new conda environment:
```shell
conda create -n openpi python=3.11 -y
conda activate openpi
```
Clone the repository with its submodules and install the dependencies:
```shell
git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
# Or, if you already cloned the repo without submodules:
git submodule update --init --recursive
# install the dependencies from inside the repo
cd openpi
GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
```
For more details, see the [openpi github page](https://github.com/Physical-Intelligence/openpi).
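
A quick way to check that the installation works is to load a trained policy directly in Python. The sketch below mirrors the calls that the `OpenPiModel` policy in this repo makes; `pi0_droid` and the `gs://openpi-assets/checkpoints/pi0_droid` checkpoint are openpi's public defaults, and any other matching config/checkpoint pair works the same way.
```python
from openpi.policies import policy_config
from openpi.shared import download
from openpi.training import config

# Load the training config and fetch the matching checkpoint (cached locally).
cfg = config.get_config("pi0_droid")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_droid")

# Create a trained policy from the checkpoint.
policy = policy_config.create_trained_policy(cfg, checkpoint_dir)
```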


## Usage
To start an agents server, use the `start-server` command, where `kwargs` is a dictionary of the constructor arguments of the policy you want to start, e.g.:
```shell
# octo
python -m agents start-server octo --host localhost --port 8080 --kwargs '{"checkpoint_path": "hf://Juelg/octo-base-1.5-finetuned-maniskill", "checkpoint_step": None, "horizon": 1, "unnorm_key": []}'
# openvla
python -m agents start-server openvla --host localhost --port 8080 --kwargs '{"checkpoint_path": "Juelg/openvla-7b-finetuned-maniskill", "device": "cuda:0", "attn_implementation": "flash_attention_2", "unnorm_key": "maniskill_human:7.0.0", "checkpoint_step": 40000}'
# openpi
python -m agents start-server openpi --port=8080 --host=localhost --kwargs='{"checkpoint_path": "<path to checkpoint>/{checkpoint_step}", "train_config_name": "pi0_rcs", "checkpoint_step": <checkpoint_step>}' # leave "{checkpoint_step}" in the path as-is; it is replaced with the value of "checkpoint_step". "train_config_name" selects the training config
```
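
The `{checkpoint_step}` placeholder in `checkpoint_path` is expanded by the openpi policy itself via `str.format` (see `OpenPiModel.__init__` in `src/agents/policies.py`). A minimal sketch with hypothetical values:
```python
# Hypothetical checkpoint path and step, for illustration only.
kwargs = {"checkpoint_path": "checkpoints/pi0_rcs/{checkpoint_step}", "checkpoint_step": 29999}

# What the policy does internally with the two values:
openpi_path = kwargs["checkpoint_path"].format(checkpoint_step=kwargs["checkpoint_step"])
print(openpi_path)  # checkpoints/pi0_rcs/29999
```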

There is also the `run-eval-during-training` command to evaluate a model during training, i.e. a single checkpoint.
src/agents/policies.py (61 additions, 0 deletions)
@@ -124,6 +124,66 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
return info


class OpenPiModel(Agent):
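    """Wrapper that exposes an openpi / pi0 policy through the agents `Agent` interface.

    The policy predicts a chunk of actions per inference call; `act` executes
    `execution_horizon` actions from the cached chunk before running inference again.
    """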

def __init__(
self,
train_config_name: str = "pi0_droid",
default_checkpoint_path: str = "gs://openpi-assets/checkpoints/pi0_droid",
        execution_horizon: int = 20,
**kwargs,
) -> None:
super().__init__(default_checkpoint_path=default_checkpoint_path, **kwargs)
from openpi.training import config

logging.info(f"checkpoint_path: {self.checkpoint_path}, checkpoint_step: {self.checkpoint_step}")
self.openpi_path = self.checkpoint_path.format(checkpoint_step=self.checkpoint_step)

self.cfg = config.get_config(train_config_name)
self.execution_horizon = execution_horizon

        # The cached chunk starts out "exhausted" so that the first act() call runs inference.
        self.chunk_counter = self.execution_horizon
        self._cached_action_chunk = None

def initialize(self):
from openpi.policies import policy_config
from openpi.shared import download

        # Resolve the checkpoint path; remote paths (e.g. the gs:// default) are
        # downloaded and cached locally, local checkpoint directories are used as-is.
        checkpoint_dir = download.maybe_download(self.openpi_path)

# Create a trained policy.
self.policy = policy_config.create_trained_policy(self.cfg, checkpoint_dir)

    def act(self, obs: Obs) -> Act:
        if self.chunk_counter < self.execution_horizon:
            # Replay the next action from the cached chunk.
            action = self._cached_action_chunk[self.chunk_counter]
            self.chunk_counter += 1
            return Act(action=action)

        # Cached chunk exhausted: query the policy for a new action chunk.
        # Camera images are converted from HWC to CHW layout.
        observation = {f"observation/{k}": np.copy(v).transpose(2, 0, 1) for k, v in obs.cameras.items()}
        observation.update(
            {
                # openpi expects 0 as gripper open and 1 as closed
                "observation/state": np.concatenate([obs.info["joints"], [1 - obs.gripper]]),
                "prompt": self.instruction,
            }
        )
        action_chunk = self.policy.infer(observation)["actions"]

        # convert gripper action into agents format
        action_chunk[:, -1] = 1 - action_chunk[:, -1]
        self._cached_action_chunk = action_chunk
        # action_chunk[0] is executed right away, so replay continues at index 1.
        self.chunk_counter = 1

        return Act(action=action_chunk[0])

    def reset(self, obs: Obs, instruction: Any) -> dict[str, Any]:
        super().reset(obs, instruction)
        # Invalidate the cached chunk so the next act() call runs inference again.
        self.chunk_counter = self.execution_horizon
        self._cached_action_chunk = None
        return {}


class OpenVLAModel(Agent):
# === Utilities ===
SYSTEM_PROMPT = (
@@ -457,4 +517,5 @@ def act(self, obs: Obs) -> Act:
openvla=OpenVLAModel,
octodist=OctoActionDistribution,
openvladist=OpenVLADistribution,
openpi=OpenPiModel,
)