This repository contains the official implementation for the paper "C²DLM: Causal Concept-Guided Diffusion Large Language Models".
Autoregressive (AR) language models and Diffusion Language Models (DLMs) constitute the two principal paradigms of large language models. However, both paradigms suffer from hallucinations and unfaithful reasoning in reasoning-intensive tasks. We hypothesize that these limitations stem from a misalignment between the attention mechanism's modeling priors and the causal priors underlying natural language. To address this issue, we propose the Causal Concept-Guided Diffusion Language Model (C²DLM). Starting from the DLM's fully connected attention, C²DLM obtains concept-level causal supervision signals from a teacher model and explicitly guides attention to learn the causal relationships between concepts, which better aligns with the underlying causal prior of natural language. Simulation experiments demonstrate the importance of incorporating these language priors. Compared with direct DLM fine-tuning, C²DLM improves the COT-OrderPerturb task by 12% with an approximately 3.2× training speedup, enhances STG tasks by 7.43% on average with an approximately 2× speedup, and achieves an average gain of 1.31% across six downstream reasoning tasks.

- C²DLM Framework: A novel framework that injects causal priors into Diffusion Language Models to improve reasoning and reduce hallucinations.
- V-aware Re-attention: A new mechanism to align the model's attention map with the causal structure of natural language, weighted by the norms of the value vectors.
- COT-OrderPerturb Dataset: A new synthetic dataset designed to quantify the impact of causal order perturbations on model reasoning robustness.
- Improved Performance & Efficiency: C²DLM demonstrates significant gains in performance and training speed across a variety of reasoning-intensive tasks.
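The V-aware Re-attention idea above can be sketched numerically. This is a minimal illustration of the stated mechanism (an attention map reweighted by the norms of the value vectors), assuming single-head, unbatched attention; the function name `v_aware_reattention` is hypothetical and this is not the repository's implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def v_aware_reattention(Q, K, V):
    """Sketch: attention probabilities reweighted by value-vector norms.

    Keys whose value vectors carry more mass receive proportionally more
    attention, so the resulting map better reflects how much each token
    actually contributes to the output.
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)          # (n, n) raw attention logits
    probs = softmax(scores, axis=-1)       # standard attention map
    v_norm = np.linalg.norm(V, axis=-1)    # (n,) per-token value norms
    weighted = probs * v_norm[None, :]     # emphasize high-norm values
    return weighted / weighted.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n, d = 4, 8
A = v_aware_reattention(rng.normal(size=(n, d)),
                        rng.normal(size=(n, d)),
                        rng.normal(size=(n, d)))
```

Each row of the returned map is still a probability distribution, so it can be compared against (or supervised toward) a concept-level causal structure.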
```
d1/
├── eval/                       # Scripts for evaluating on downstream tasks (the official repository of "d1: Scaling Reasoning in Diffusion Large Language Models via Reinforcement Learning")
│   ├── math500.py
│   ├── stg.py
│   ├── calculate_gpqa_arc_acc.py
│   ├── parse_and_get_acc.py
│   └── ...
├── eval_results/               # Raw model generation outputs for evaluation
│   ├── gsm8k_512_BLOCK_LENGTH32_C2DLM/
│   ├── math_512_BLOCK_LENGTH32_C2DLM/
│   └── ...
├── draw/                       # Utility scripts for plotting and analysis
│   ├── bar_viz.py
│   ├── causal_llada_util.py
│   └── ...
├── train_COT_OrderPerturb/     # Implementation of our proposed COT-OrderPerturb task
│   └── experiment/
│       ├── cot_dataset_test.jsonl
│       └── multi_gpu_train.py
└── train_downstream/           # Scripts for training on general downstream tasks
    └── ...
```
The training scripts leverage torchrun for multi-GPU execution.

To train a model on our proposed COT-OrderPerturb dataset:

```bash
cd train_COT_OrderPerturb/experiment
torchrun --nproc_per_node=[NUM_GPUS] multi_gpu_train.py [YOUR_ARGUMENTS]
```

To train a model on other downstream reasoning tasks:

```bash
cd train_downstream
torchrun --nproc_per_node=[NUM_GPUS] multi_gpu_train.py [YOUR_ARGUMENTS]
```

Evaluation is a two-step process: first generate model outputs, then calculate accuracy.
Use the eval.py script to produce raw outputs from your trained model. The results are saved in the eval_results/ directory.

```bash
cd eval
python eval.py [YOUR_ARGUMENTS]
```

Once generation is complete, use the provided parsing scripts to calculate the final accuracy. These scripts read the output folders generated in the previous step.
Each folder in eval_results contains the model's raw output with the following fields: question, prompt_input, generations, and ground_truth.
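To make the record format concrete, here is a minimal sketch of scoring such a folder. The field names come from the list above; everything else is an illustrative assumption, including that records are stored as JSON Lines, the last-number extraction heuristic, and the helper names `last_number` and `folder_accuracy`. Use the repository's parsing scripts for actual results:

```python
import json
import re
from pathlib import Path

def last_number(text):
    """Extract the last number in a string (a common GSM8K-style heuristic)."""
    nums = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return nums[-1] if nums else None

def folder_accuracy(folder):
    """Compare the final number in each generation to ground_truth."""
    correct = total = 0
    for path in Path(folder).glob("*.json*"):
        for line in path.read_text().splitlines():
            if not line.strip():
                continue
            # Assumed fields: question, prompt_input, generations, ground_truth
            record = json.loads(line)
            pred = last_number(str(record["generations"]))
            gold = last_number(str(record["ground_truth"]))
            total += 1
            correct += int(pred is not None and pred == gold)
    return correct / total if total else 0.0
```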
```bash
cd eval

# For GSM8K and MATH500, use the general accuracy script
python parse_and_get_acc.py --file ../eval_results/[MY_OUTPUT_FOLDER]

# For GPQA, ARC, and MMLU_STEM, use the dedicated calculators
python calculate_gpqa_arc_mmlu_acc.py ../eval_results/[MY_GPQA_OUTPUT_FOLDER]
python calculate_sat_acc.py ../eval_results/[MY_SAT_OUTPUT_FOLDER]
```

The draw/ directory contains the scripts used to generate the attention visualizations and performance plots in our paper.
