A lightweight adaptation of Diffusion Policy used in Real2Render2Real.
DISCLAIMER: Much of the tangential code was removed from the original repo to keep this one lightweight, so there may be issues or bugs. Please report any you find.
```bash
# we use python=3.10.15
conda create -n tinydp python=3.10.15
conda activate tinydp
pip install -e .
```

To optimize for training throughput, the h5 files generated in the real2render2real repo are converted to zarr files and jpg images:

```bash
python script/convert_dataset_to_mp4.py --root-dir /PATH/TO/DATASET
```
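The exact on-disk layout written by the conversion script is not documented here, but as a rough sanity check you can open a converted store with the zarr library. This is only a sketch: the store path below is a placeholder, not the script's actual output name.

```python
import zarr  # third-party library used for the converted dataset

# Placeholder path: point this at an actual .zarr store written under your dataset root.
root = zarr.open("/PATH/TO/DATASET/SOME_CONVERTED_STORE.zarr", mode="r")

# Print the top-level groups/arrays with basic shape/dtype info (groups print None).
for name in root.keys():
    item = root[name]
    print(name, getattr(item, "shape", None), getattr(item, "dtype", None))
```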
We provide example training scripts below:

```bash
EPOCHS=200
DATASET_ROOT=/PATH/TO/DATASET
LOG_NAME=YOUR_LOG_NAME
OUTPUT_DIR=./dpgs_checkpoints
python script/train.py --dataset-cfg.dataset-root $DATASET_ROOT --logging-cfg.log-name $LOG_NAME --logging-cfg.output-dir $OUTPUT_DIR --trainer-cfg.epochs $EPOCHS
```

The code can also be run on multiple GPUs via DDP:
```bash
MASTERPORT=2222
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node=2 --master_port=$MASTERPORT script/train.py --dataset-cfg.dataset-root $DATASET_ROOT --logging-cfg.log-name $LOG_NAME --logging-cfg.output-dir $OUTPUT_DIR --trainer-cfg.epochs $EPOCHS
```

Please see other knobs you can tune via:
```bash
python script/train.py --help
```

If you are using sim data, please add the following flag:

```bash
--dataset-cfg.is-sim-data
```

If only one arm in a bimanual setup is used, please add ONE of the following flags:
```bash
--model-cfg.policy-cfg.pred-left-only
--model-cfg.policy-cfg.pred-right-only
```

If you want to subsample the data, you can add the following flag:
```bash
SUBSAMPLE_NUM_TRAJ=100
--dataset-cfg.data-subsample-num-traj $SUBSAMPLE_NUM_TRAJ
```
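For reference, a hypothetical run that combines several of these flags (sim data, right arm only, subsampled trajectories) might look like the following; drop whichever flags do not apply to your setup:

```bash
python script/train.py \
    --dataset-cfg.dataset-root $DATASET_ROOT \
    --dataset-cfg.is-sim-data \
    --dataset-cfg.data-subsample-num-traj $SUBSAMPLE_NUM_TRAJ \
    --model-cfg.policy-cfg.pred-right-only \
    --logging-cfg.log-name $LOG_NAME \
    --logging-cfg.output-dir $OUTPUT_DIR \
    --trainer-cfg.epochs $EPOCHS
```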
We provide an example inference snippet below:

```python
from tinydp.policy.diffusion_wrapper import DiffusionWrapper

model_ckpt_folder = "/PATH/TO/MODEL/CHECKPOINT"
ckpt_id = 50
device = "cuda"
inferencer = DiffusionWrapper(model_ckpt_folder, ckpt_id, device=device)

while True:
    # nbatch: Batch dictionary containing:
    #   - observation: torch.Tensor: Images of shape (B, T, num_cameras, C, H, W)
    #   - proprio: torch.Tensor: Proprioceptive data of shape (B, T, D)
    nbatch = {
        "observation": ...,
        "proprio": ...,
    }
    nbatch = {k: v.to(device) for k, v in nbatch.items()}
    # Output shape: (batch, action_horizon, action_dim); actions are denormalized
    # with the dataset statistics.
    pred_action = inferencer(nbatch)
```
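As a rough sketch of how nbatch might be assembled from raw sensor data (the camera count, image resolution, proprio dimension, observation horizon, and the [0, 1] image scaling below are placeholder assumptions; match whatever preprocessing your checkpoint was trained with):

```python
import numpy as np
import torch

# Hypothetical raw inputs: two camera frames (H x W x C, uint8) and a proprio vector.
cam_images = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(2)]
proprio = np.zeros(16, dtype=np.float32)

# Stack to (num_cameras, C, H, W), scale to [0, 1] (an assumption), then add batch
# and time dims to match the expected (B, T, num_cameras, C, H, W) layout.
obs = torch.stack(
    [torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 for img in cam_images]
)
obs = obs.unsqueeze(0).unsqueeze(0)                          # (1, 1, num_cameras, C, H, W)
prop = torch.from_numpy(proprio).unsqueeze(0).unsqueeze(0)   # (1, 1, D)

nbatch = {"observation": obs, "proprio": prop}
```

The resulting dictionary can then be moved to the device and passed to the inferencer as in the loop above.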
A more complete example can be found in script/inference.py.
Much of the code is borrowed from Diffusion Policy.