This is an unofficial implementation of WoG based on open-source frameworks and datasets.
Thanks to OpenVLA and CogACT for their awesome codebases.
```shell
conda create -n wog python=3.10
```

Next, clone our repo and install the required packages:

```shell
git clone https://github.com/Selen-Suyue/WoG
cd WoG
pip install -e .
make setup
```

Flash Attention is needed for training. You can simply run:

```shell
pip install -e .[train]
```

or install it manually:

```shell
pip install packaging ninja
pip install "flash-attn==2.5.5" --no-build-isolation
```

You can start the training from the weights of OpenVLA for greater efficiency. Please follow the instructions of OpenVLA to download their weights:

```shell
mkdir -p pretrained
cd pretrained
git clone [email protected]:openvla/openvla-7b-prismatic
cd openvla-7b-prismatic
git lfs pull
```

Also, you can download the pretrained checkpoints (.pth) of the vision foundation models into pretrained/vision. The Wan VAE is available at Wan VAE. The DINO and SigLIP weights can be downloaded via utils/download_weights.py.
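The expected layout can be prepared as below. This is a sketch: the helper-script invocation is commented out because it assumes this repo's utils/download_weights.py and must be run from the repo root, and the Wan VAE checkpoint has to be fetched separately.

```shell
# Create the directory that holds the vision foundation model checkpoints (.pth).
mkdir -p pretrained/vision
# DINO and SigLIP weights via the repo's helper script (run from the repo root):
# python utils/download_weights.py
# Place the separately downloaded Wan VAE checkpoint in pretrained/vision as well.
ls pretrained
```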
The data of Open X-Embodiment (OXE) can be downloaded following OXE and OpenVLA. Then launch the training script; for example, on one node with 8 A100 GPUs:
```shell
## For Stage I:
torchrun --standalone --nnodes 1 --nproc-per-node 8 scripts/train.py \
  --pretrained_checkpoint pretrained/openvla-7b-prismatic/checkpoints/step-295000-epoch-40-loss=0.2200.pt \
  --vla.type prism-dinosiglip-224px+oxe+diffusion \
  --vla.data_mix oxe_magic_soup_plus \
  --vla.expected_world_size 8 \
  --vla.global_batch_size 256 \
  --vla.per_device_batch_size 32 \
  --vla.learning_rate 2e-5 \
  --data_root_dir <path_to_dataset_dir> \
  --run_root_dir <path_to_log/checkpoint_dir> \
  --run_id <optional_run_id_for_wandb> \
  --image_aug <True_or_False> \
  --wandb_project <your_wandb_project> \
  --wandb_entity <your_wandb_entity> \
  --save_interval <num_of_steps_to_save_checkpoint> \
  --repeated_diffusion_steps 8 \
  --Ta 16 \
  --action_model_type DiT-L \
  --is_resume False \
  --train_venc True
```
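As a quick consistency check on the flags above: the global batch size should equal the per-device batch size times the world size. This relation is how OpenVLA-style launchers typically validate these flags; treat the exact enforcement in this repo as an assumption.

```python
# Sanity check for the launch flags above:
# --vla.global_batch_size == --vla.expected_world_size * --vla.per_device_batch_size
expected_world_size = 8        # one node x 8 A100 GPUs
per_device_batch_size = 32
global_batch_size = 256

assert global_batch_size == expected_world_size * per_device_batch_size
print("batch-size flags are consistent")
```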
```shell
## For Stage II:
torchrun --standalone --nnodes 1 --nproc-per-node 8 scripts/train.py \
  --pretrained_checkpoint <ckpt pretrained in Stage I> \
  --vla.type prism-dinosiglip-224px+oxe+diffusion \
  --vla.data_mix oxe_magic_soup_plus \
  --vla.expected_world_size 8 \
  --vla.global_batch_size 256 \
  --vla.per_device_batch_size 32 \
  --vla.learning_rate 2e-5 \
  --data_root_dir <path_to_dataset_dir> \
  --run_root_dir <path_to_log/checkpoint_dir> \
  --run_id <optional_run_id_for_wandb> \
  --image_aug <True_or_False> \
  --wandb_project <your_wandb_project> \
  --wandb_entity <your_wandb_entity> \
  --save_interval <num_of_steps_to_save_checkpoint> \
  --repeated_diffusion_steps 8 \
  --Ta 16 \
  --action_model_type DiT-L \
  --is_resume False
```

We provide pretrained checkpoints of the two stages on Hugging Face: WoG-V is the first-stage checkpoint and WoG-A is the second.
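The released checkpoints can be fetched with the Hugging Face CLI. Only the WoG-V / WoG-A names come from the release notes above; the organization name below is a placeholder, so check the project's Hugging Face page for the real repo IDs.

```python
# Map training stage -> released checkpoint name (names from this README;
# "<org>" is a placeholder for the actual Hugging Face organization).
STAGE_TO_REPO = {"I": "WoG-V", "II": "WoG-A"}

def download_cmd(stage: str, root: str = "pretrained") -> str:
    """Build a huggingface-cli command that fetches the checkpoint for a stage."""
    repo = STAGE_TO_REPO[stage]
    return f"huggingface-cli download <org>/{repo} --local-dir {root}/{repo}"

# e.g. the second-stage (action) checkpoint:
print(download_cmd("II"))
```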
In this section, we provide a minimal evaluation of our models in SIMPLER. First, please follow the instructions of SimplerEnv to install the simulation environment. Next, add our ./deploy to SimplerEnv/simpler_env/policies:
```shell
cp -r ./deploy <your_path_to_simpler>/simpler_env/policies
```

Then add a new policy model in SimplerEnv/simpler_env/main_inference.py as below:

```python
elif args.policy_model == "wog":
    from simpler_env.policies.deploy import WoGSimInference

    assert args.ckpt_path is not None
    model = WoGSimInference(
        saved_model_path=args.ckpt_path,
        policy_setup=args.policy_setup,
        action_scale=args.action_scale,
        action_model_type="DiT-L",
    )
```

After that, you can modify and launch the scripts in deploy/scripts, for example:

```shell
cd <your_path_to_simpler>
bash simpler_env/policies/deploy/scripts/wog_put_in_drawer_visual_matching.sh
```

**CALVIN Deployment**
We now support CALVIN deployment. You can convert the CALVIN dataset to RLDS format with calvin_rlds_builder. We've registered the CALVIN dataset in:
- prismatic/vla/datasets/rlds/oxe/transforms.py
- prismatic/vla/datasets/rlds/oxe/mixtures.py
- prismatic/vla/datasets/rlds/oxe/configs.py
for training. For deployment:

```shell
mv deploy/wog_policy_calvin.py your_path_to_calvin/calvin_models/calvin_agent/evaluation/
```

We provide a sample real-world deployment wrapper in deploy/wog_policy_real.py. It is supported by ManiUnicon, and you can also use it on other platforms.
ManiUnicon is a universal real-world robot control platform for data collection and model deployment. You can refer to its README to collect data in RLDS format.
Once you have collected an RLDS dataset, modify the following files in our project:
- prismatic/vla/datasets/rlds/oxe/transforms.py
- prismatic/vla/datasets/rlds/oxe/mixtures.py
- prismatic/vla/datasets/rlds/oxe/configs.py
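The sketch below illustrates what these registrations look like. The field names and the dataset name my_real_robot are assumptions modeled on OpenVLA's OXE config style, not this project's actual schema; the mixture name real_dataset_mix matches the --vla.data_mix flag used in the real-world training command.

```python
# Illustrative sketch only: the real schemas live in
# prismatic/vla/datasets/rlds/oxe/{configs,mixtures,transforms}.py.
# All field and dataset names here ("my_real_robot", etc.) are hypothetical.

# configs.py-style entry: which observation keys the RLDS dataset exposes.
MY_DATASET_CONFIG = {
    "my_real_robot": {
        "image_obs_keys": {"primary": "image", "wrist": None},
        "state_obs_keys": ["state"],
    },
}

# mixtures.py-style entry: (dataset name, sampling weight) pairs, referenced
# via the --vla.data_mix flag (here: real_dataset_mix).
REAL_DATASET_MIX = [("my_real_robot", 1.0)]

# transforms.py-style entry: map raw RLDS steps into the expected format
# (rename keys, rescale actions, etc.); identity here for illustration.
def my_real_robot_transform(trajectory):
    return trajectory
```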
Then only Stage II is needed for training:

```shell
## For RealWorld:
torchrun --standalone --nnodes 1 --nproc-per-node 8 scripts/train.py \
  --pretrained_checkpoint <ckpt pretrained in Stage I> \
  --vla.type prism-dinosiglip-224px+oxe+diffusion \
  --vla.data_mix real_dataset_mix \
  --vla.expected_world_size 8 \
  --vla.global_batch_size 256 \
  --vla.per_device_batch_size 32 \
  --vla.learning_rate 2e-5 \
  --data_root_dir <path_to_dataset_dir> \
  --run_root_dir <path_to_log/checkpoint_dir> \
  --run_id <optional_run_id_for_wandb> \
  --image_aug <True_or_False> \
  --wandb_project <your_wandb_project> \
  --wandb_entity <your_wandb_entity> \
  --save_interval <num_of_steps_to_save_checkpoint> \
  --repeated_diffusion_steps 8 \
  --Ta 16 \
  --action_model_type DiT-L \
  --is_resume False
```

After training, you can follow the WoG config and the WoG class in ManiUnicon to deploy WoG.
```bibtex
@article{WoG,
  title={World Guidance: World Modeling in Condition Space for Action Generation},
  author={Yue Su and Sijin Chen and Haixin Shi and Mingyu Liu and Zhengshen Zhang and Ningyuan Huang and Weiheng Zhong and Zhengbang Zhu and Yuxiao Liu and Xihui Liu},
  journal={arXiv preprint arXiv:2602.22010},
  year={2026},
}
```

All the code, model weights, and data are licensed under the MIT license.
