# Official Implementation of "Rejection Mixing: Fast Semantic Propagation of Mask Tokens for Efficient DLLM Inference"
This repository offers the codebase for ReMix, including comprehensive scripts for reproducing demos and evaluations on LLaDA and MMaDA.
## Installation

We recommend using uv for dependency and virtual environment management.

### For LLaDA

```bash
pip install uv
cd LLaDA
uv venv --python 3.11 dev
source dev/bin/activate
uv pip install -r requirements.txt
```

### For MMaDA

```bash
cd MMaDA
pip install uv
uv venv --python 3.11 dev
source dev/bin/activate
uv pip install -r requirements.txt
```

### For lmms-eval

```bash
cd lmms_eval
uv pip install -e .
```

## Prepare Model and Datasets
Before running inference or evaluation, please download the following models and datasets from Hugging Face into the specified local directories (e.g., `./LLaDA/models/` and `./LLaDA/data/`).
You may use either `huggingface-cli` or the Python `datasets` library to complete the download.
| Model Name | Hugging Face Repo | Local Path |
|---|---|---|
| LLaDA-8B-Instruct | GSAI-ML/LLaDA-8B-Instruct | ./LLaDA/models/LLaDA-8B-Instruct/ |

| Dataset Name | Hugging Face Repo | Local Path |
|---|---|---|
| GSM8K | openai/gsm8k | ./LLaDA/data/gsm8k/ |
| MATH-500 | HuggingFaceH4/MATH-500 | ./LLaDA/data/math500/ |
| HumanEval | openai/openai_humaneval | ./LLaDA/data/humaneval/ |
| ai2_arc | allenai/ai2_arc | ./LLaDA/data/ai2_arc/ |
Datasets not listed above are already included in the `./LLaDA/data/` directory.
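The per-asset downloads in the tables above can be scripted. The sketch below only prints the corresponding `huggingface-cli download` commands for each entry (the `--repo-type` and `--local-dir` flags are standard `huggingface-cli` options); it does not download anything itself, so you can review the commands before running them.

```python
# Map each Hugging Face repo to the local path the ReMix configs expect.
MODELS = {
    "GSAI-ML/LLaDA-8B-Instruct": "./LLaDA/models/LLaDA-8B-Instruct/",
}
DATASETS = {
    "openai/gsm8k": "./LLaDA/data/gsm8k/",
    "HuggingFaceH4/MATH-500": "./LLaDA/data/math500/",
    "openai/openai_humaneval": "./LLaDA/data/humaneval/",
    "allenai/ai2_arc": "./LLaDA/data/ai2_arc/",
}

# Print the download commands; pipe the output to `sh` to execute them.
for repo, path in MODELS.items():
    print(f"huggingface-cli download {repo} --local-dir {path}")
for repo, path in DATASETS.items():
    print(f"huggingface-cli download --repo-type dataset {repo} --local-dir {path}")
```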
## LLaDA

### Demo

We provide a quick demo of our method on LLaDA; make sure to set `model_path` to your local model path.

```bash
cd LLaDA
python demo.py
```

### Evaluation
Configuration files for the benchmarks listed above are located in `./LLaDA/configs/`. To ensure a successful evaluation, you must fill in `data_root` and `model_path` in the corresponding YAML file before running the script.
To run the evaluation with default settings, simply execute:
```bash
cd LLaDA
bash eval.sh
```

If you wish to adjust the generation parameters (e.g., `gen_length`, `steps`, and `threshold`), you have two options:

- **Modify Configuration Files**: For persistent settings, edit the corresponding YAML file in `./LLaDA/configs/`.
- **Command-Line Overrides**: For temporary adjustments or rapid experimentation, modify the command in `eval.sh` by adding the `--gen-kwargs` flag. For example:
```bash
torchrun --nproc_per_node=8 eval.py \
    --config configs/gsm8k.yaml \
    --method remix \
    --gen-kwargs threshold=0.8,js_threshold=0.2,beta_mix=0.6
```

> [!NOTE]
> Parameters passed via `--gen-kwargs` will override the values specified in the YAML configuration.
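The override semantics amount to a shallow merge of the comma-separated `key=value` pairs over the YAML defaults. The following is a hypothetical sketch of that behavior (`parse_gen_kwargs` and the default values shown are illustrative, not taken from the repo):

```python
def parse_gen_kwargs(s: str) -> dict:
    """Parse 'k1=v1,k2=v2' into a dict, coercing numeric values."""
    out = {}
    for pair in s.split(","):
        key, value = pair.split("=", 1)
        try:
            value = float(value) if "." in value else int(value)
        except ValueError:
            pass  # leave non-numeric values as strings
        out[key.strip()] = value
    return out

# Hypothetical YAML defaults; CLI values win on key collision.
yaml_defaults = {"gen_length": 256, "steps": 128, "threshold": 0.9}
cli = parse_gen_kwargs("threshold=0.8,js_threshold=0.2,beta_mix=0.6")
merged = {**yaml_defaults, **cli}
print(merged["threshold"])  # 0.8 -- the CLI value overrides the YAML 0.9
```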
### Further Development

To compare ReMix with other dLLM inference acceleration techniques, you can implement additional decoding functions in `./LLaDA/model/decoding.py`.
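As a starting point, a confidence-threshold decoding rule can be sketched as below. This is a hypothetical, dependency-free illustration only; the actual function signatures and tensor shapes in `./LLaDA/model/decoding.py` will differ, so adapt the names accordingly.

```python
import math

def decode_step(logits, is_masked, threshold=0.9):
    """For each masked position, commit the argmax token only if its softmax
    probability exceeds `threshold`; otherwise keep it masked (None).
    `logits` is a list of per-position score lists; `is_masked` flags which
    positions are still mask tokens."""
    decoded = []
    for row, masked in zip(logits, is_masked):
        if not masked:
            decoded.append(None)  # already-decoded positions are left alone
            continue
        m = max(row)
        exp = [math.exp(x - m) for x in row]  # numerically stable softmax
        conf = max(exp) / sum(exp)
        decoded.append(exp.index(max(exp)) if conf >= threshold else None)
    return decoded
```

For example, `decode_step([[5.0, 0.0, 0.0], [1.0, 1.1, 0.9]], [True, True])` commits token 0 at the first position (high confidence) and keeps the second position masked (near-uniform logits).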
## MMaDA

### Demo

We provide a quick demo of our method on MMaDA; make sure to set `model_path` to your local model path.

```bash
cd MMaDA
python demo.py
```

### Evaluation
We utilize lmms-eval for MMaDA evaluation.
> [!IMPORTANT]
> Some benchmarks (e.g., MathVista) require an auxiliary model for evaluation. Ensure your `OPENAI_API_KEY` and `OPENAI_API_URL` are properly configured before running the scripts.

Next, to run the evaluation, simply execute:

```bash
cd MMaDA
bash eval.sh
```
### Further Development

To compare ReMix with other dLLM inference acceleration techniques, you can implement additional decoding functions within the `MMadaModelLM` class. Depending on your use case, modify the corresponding file:

- For `demo.py`: `./MMaDA/models/modeling_mmada.py`
- For evaluation: `./MMaDA/lmms_eval/lmms_eval/models/model_mmada/modeling_mmada.py`
Additionally, to reproduce TPS and latency metrics or apply custom modifications, please refer to the MMaDA.generate_until method in ./MMaDA/lmms_eval/lmms_eval/models/mmada.py.
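TPS is conventionally computed as generated tokens divided by wall-clock decoding time. The minimal sketch below shows that bookkeeping; it is an illustration only, and the actual measurement in `MMaDA.generate_until` may count tokens or time differently.

```python
import time

class ThroughputMeter:
    """Accumulates generated-token counts and elapsed seconds to report TPS."""

    def __init__(self):
        self.tokens = 0
        self.seconds = 0.0

    def record(self, n_new_tokens, elapsed):
        self.tokens += n_new_tokens
        self.seconds += elapsed

    @property
    def tps(self):
        return self.tokens / self.seconds if self.seconds else 0.0

meter = ThroughputMeter()
start = time.perf_counter()
# ... the call to the model's generation loop would go here ...
meter.record(n_new_tokens=256, elapsed=time.perf_counter() - start)
```

Average latency per request follows the same pattern: accumulate `elapsed` per call and divide by the number of calls.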
## Acknowledgement

This implementation is based on the WINO codebase.