⚡ Accelerate Action Diffusion by 5× via Dynamic Test-time Pruning and Omnidirectional Feature Reusing
CVPR 2026
Test-time Sparsity (TTS) accelerates action diffusion by dynamically predicting, at test time, which residual computations can be pruned in each forward pass of the model. Our method reduces FLOPs by 92% and achieves a 5× wall-clock speedup, reaching an inference frequency of 47.5 Hz on an NVIDIA RTX 4090 GPU without performance degradation.
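As a quick sanity check on these headline numbers (not part of the codebase), the 5× speedup at 47.5 Hz implies a baseline frequency of roughly 9.5 Hz, while 92% FLOPs reduction would give a 12.5× speedup if inference were purely compute-bound:

```python
# Back-of-envelope check of the headline numbers (illustrative only).
flops_reduction = 0.92                       # 92% of FLOPs pruned
ideal_speedup = 1 / (1 - flops_reduction)    # 12.5x if purely compute-bound
measured_speedup = 5.0
tts_hz = 47.5
baseline_hz = tts_hz / measured_speedup      # ~9.5 Hz baseline policy

print(f"ideal compute-bound speedup: {ideal_speedup:.1f}x")
print(f"implied baseline frequency:  {baseline_hz:.1f} Hz")
```

The gap between the 12.5× ideal and the 5× measured speedup is the usual memory-bandwidth and kernel-launch overhead that FLOPs counts do not capture.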
- Dynamic Test-time Pruning: A lightweight pruner that shares the encoder with the diffusion transformer, dynamically predicting skippable residual computations before each forward pass
- Omnidirectional Feature Reusing: Achieves 95% sparsity by selectively reusing features cached from the current forward, previous denoising timesteps, and earlier rollout iterations
- Highly Parallelized Pipeline: Decouples encoding and pruning from the autoregressive denoising loop, reducing non-decoder delay to milliseconds via parallel processing and asynchronous execution
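The reuse idea can be sketched in a few lines (hypothetical names and scalar features for brevity; the real pruner and cache live in `TTSInfer` and operate on tensors): a cache keyed by block index holds residual-block outputs, and a per-forward binary mask decides which blocks recompute and which reuse the cached feature from an earlier forward.

```python
from typing import Callable, Dict, List

# Hypothetical sketch of cached residual-block reuse -- not the TTSInfer API.
class FeatureReuseCache:
    def __init__(self) -> None:
        self.cache: Dict[int, float] = {}  # block index -> cached feature

    def forward(self, x: float, blocks: List[Callable[[float], float]],
                keep_mask: List[bool]) -> float:
        for i, block in enumerate(blocks):
            if keep_mask[i] or i not in self.cache:
                self.cache[i] = block(x)   # recompute and refresh the cache
            x = x + self.cache[i]          # residual add: fresh or reused
        return x
```

With a mostly-zero `keep_mask` (the paper reports up to 95% sparsity), almost every residual add uses a cached feature, so the per-forward cost collapses to the few recomputed blocks.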
```bash
git clone https://github.com/ky-ji/Test-time-Sparsity.git
cd Test-time-Sparsity
git submodule update --init --recursive
conda env create -f conda_environment.yaml
conda activate robodiff
pip install -e .
```

This section reproduces the main results from Tables 1–3 of the paper using the TTSInfer module.
Follow the Diffusion Policy data download instructions to download the robosuite / FurnitureBench datasets. By default the code expects data under a data/ directory.
Supported tasks: can_ph, can_mh, lift_ph, lift_mh, square_ph, square_mh, transport, tool_hang, kitchen.
```bash
cd diffusion_policy
python train.py --config-name=train_diffusion_transformer_hybrid_workspace \
    task=<task_name>
```

Or download pre-trained checkpoints from the Diffusion Policy project page.
Run the trained baseline policy and collect successful rollout trajectories:
```bash
python -m TTSInfer.scripts.collect_trajectory_data \
    --checkpoint /path/to/diffusion_policy.ckpt \
    --output_dir pruner_tra_data_max/trajectories/<task_name> \
    --num_episodes 100 \
    --device cuda:0
```

Only successful episodes are saved. Repeat for each task. Expected output structure:
```
pruner_tra_data_max/trajectories/<task_name>/
├── dataset_summary.json
└── episodes/
    ├── episode_000000/
    │   ├── metadata.json
    │   ├── trajectory.pt
    │   └── videos/
    └── ...
```
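A small helper for walking the collected episodes could look like this (illustrative; the metadata field names are assumptions, and `trajectory.pt` would be loaded with `torch.load`, commented out here to keep the sketch dependency-free):

```python
import json
from pathlib import Path
from typing import Iterator, Tuple

def iter_episodes(task_dir: str) -> Iterator[Tuple[str, dict]]:
    """Yield (episode_name, metadata) for each collected episode."""
    episodes_dir = Path(task_dir) / "episodes"
    for ep in sorted(episodes_dir.iterdir()):
        meta_path = ep / "metadata.json"
        if not meta_path.is_file():
            continue  # skip stray files like '...'
        meta = json.loads(meta_path.read_text())
        # trajectory = torch.load(ep / "trajectory.pt")  # obs/action tensors
        yield ep.name, meta
```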
```bash
python -m TTSInfer.scripts.train_eval.train_pruner \
    --task_name can_ph \
    --device cuda:0 \
    --output_dir sim_result \
    --config TTSInfer/pruner_config/training_config.yaml \
    --datatype max \
    --train_version 0
```

Output: `sim_result/pruner_ckpt/<timestamp>/<train_id>/<task_name>/pruner_model_<epoch>_<loss>.pt`
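Since each checkpoint encodes its epoch and loss in the filename, selecting the lowest-loss one for `--epoch <best_epoch>` can be automated; a sketch, assuming the `pruner_model_<epoch>_<loss>.pt` pattern holds exactly:

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Matches e.g. "pruner_model_80_0.0500.pt" -> epoch 80, loss 0.05.
CKPT_RE = re.compile(r"pruner_model_(\d+)_([\d.]+)\.pt$")

def best_checkpoint(ckpt_dir: str) -> Optional[Tuple[int, float, Path]]:
    """Return (epoch, loss, path) of the lowest-loss pruner checkpoint."""
    best = None
    for p in Path(ckpt_dir).glob("pruner_model_*.pt"):
        m = CKPT_RE.search(p.name)
        if not m:
            continue
        epoch, loss = int(m.group(1)), float(m.group(2))
        if best is None or loss < best[1]:
            best = (epoch, loss, p)
    return best
```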
```bash
python -m TTSInfer.scripts.train_eval.eval_pruner \
    --output_dir sim_result/pruner_ckpt \
    --timestamp <train_timestamp> \
    --task_name can_ph \
    --train_id 0 \
    --epoch <best_epoch> \
    --device cuda:0
```

To benchmark inference speed only:

```bash
python -m TTSInfer.scripts.exp.eval_speed_only \
    -t can_ph \
    -e <epoch> \
    --timestamp <train_timestamp> \
    --train_root <path_to_sim_result/pruner_ckpt/<timestamp>/<train_id>/<task_name>> \
    --device cuda:0
```

The realworld-TTS/ module provides a complete pipeline for deploying TTS on real robots.
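For quick checks outside the provided eval scripts, whether in simulation or on the robot, a minimal wall-clock harness suffices to estimate inference frequency (generic sketch, not tied to this repo's code):

```python
import time
from typing import Callable

def measure_hz(policy_step: Callable[[], None], n_warmup: int = 5,
               n_iters: int = 50) -> float:
    """Average inference frequency (Hz) of a zero-argument policy call."""
    for _ in range(n_warmup):
        policy_step()                 # warm up: CUDA init, caches, compilation
    start = time.perf_counter()
    for _ in range(n_iters):
        policy_step()                 # steady-state timed loop
    elapsed = time.perf_counter() - start
    return n_iters / elapsed
```

For GPU policies, wrap `policy_step` so it calls `torch.cuda.synchronize()` after each forward; otherwise asynchronous kernel launches make the loop appear faster than it is.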
```bash
cd realworld-TTS
ZARR_DATASET=/path/to/your_dataset.zarr \
TRAJECTORY_DIR=/path/to/output_trajectory \
CHECKPOINT=/path/to/dp_policy.ckpt \
bash train/convert_data.sh
```

Copy and adapt a training config for your task:
```bash
cp train/configs/train_pruner_config_pick.yaml train/configs/my_task.yaml
```

```bash
cd realworld-TTS
CONFIG=train/configs/my_task.yaml \
CHECKPOINT=/path/to/dp_policy.ckpt \
TRAJECTORY_DIR=/path/to/output_trajectory \
OUTPUT_DIR=./output/pruner_tts \
bash train/train_pruner.sh
```

Multi-GPU training is supported automatically when NUM_GPUS > 1.
```bash
cd realworld-TTS
python tts_accelerator/scripts/run_dp_with_tts.py \
    --config tts_accelerator/configs/assembly_bun.yaml
```

Edit realworld-TTS/server/configs/server_config_local.py to set checkpoint paths, device, scheduler type, and pruning ratio.
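The fields mentioned above might look roughly like the following (illustrative only — the actual variable names in `server_config_local.py` may differ, so treat every name here as an assumption):

```python
# Illustrative server config fragment -- field names are assumptions; check
# realworld-TTS/server/configs/server_config_local.py for the real ones.
POLICY_CHECKPOINT = "/path/to/dp_policy.ckpt"
PRUNER_CHECKPOINT = "/path/to/pruner_model.pt"
DEVICE = "cuda:0"
SCHEDULER_TYPE = "ddim"   # denoising scheduler used at test time
PRUNING_RATIO = 0.95      # fraction of residual computations reused
```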
This project is licensed under the MIT License - see the LICENSE file for details.

