C²DLM: Causal Concept-Guided Diffusion Large Language Model

This repository contains the official implementation for the paper "C²DLM: Causal Concept-Guided Diffusion Large Language Models".

Abstract

Autoregressive (AR) language models and Diffusion Language Models (DLMs) constitute the two principal paradigms of large language models. However, both paradigms suffer from hallucinations and unfaithful reasoning in reasoning-intensive tasks. We hypothesize that these limitations stem from a misalignment between the attention mechanism's modeling priors and the causal priors underlying natural language. To address this issue, we propose the Causal Concept-Guided Diffusion Language Model (C²DLM). Starting from the DLM's fully connected attention, C²DLM obtains concept-level causal supervision signals from a teacher model and explicitly guides attention to learn the causal relationships between concepts, which better aligns with the underlying causal prior of natural language. Simulation experiments demonstrate the significance of incorporating causal priors. Compared with direct DLM fine-tuning, C²DLM improves accuracy on the COT-OrderPerturb task by 12% with about a 3.2× training speedup, enhances STG tasks by 7.43% on average with about a 2× speedup, and achieves an average gain of 1.31% across six downstream reasoning tasks.

Key Contributions

  • C²DLM Framework: A novel framework that injects causal priors into Diffusion Language Models to improve reasoning and reduce hallucinations.
  • V-aware Re-attention: A new mechanism to align the model's attention map with the causal structure of natural language, weighted by the norms of the value vectors.
  • COT-OrderPerturb Dataset: A new synthetic dataset designed to quantify the impact of causal order perturbations on model reasoning robustness.
  • Improved Performance & Efficiency: C²DLM demonstrates significant gains in performance and training speed across a variety of reasoning-intensive tasks.
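To give a concrete feel for the V-aware Re-attention idea described above, here is a minimal NumPy sketch of value-norm-weighted attention. This is an illustrative simplification, not the exact formulation from the paper: the function name, shapes, and the renormalization step are assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def v_aware_reattention(q, k, v):
    """Sketch: re-weight attention columns by the norm of each value vector,
    so tokens whose values carry more content receive more attention mass.
    Illustrative only; not the paper's exact mechanism."""
    d = q.shape[-1]
    attn = softmax(q @ k.T / np.sqrt(d))               # (T, T) standard attention
    v_norm = np.linalg.norm(v, axis=-1)                # (T,) value-vector norms
    reweighted = attn * v_norm[None, :]                # scale column j by ||v_j||
    reweighted /= reweighted.sum(axis=-1, keepdims=True)  # renormalize each row
    return reweighted @ v

rng = np.random.default_rng(0)
T, d = 5, 8
out = v_aware_reattention(rng.normal(size=(T, d)),
                          rng.normal(size=(T, d)),
                          rng.normal(size=(T, d)))
print(out.shape)  # (5, 8)
```

In C²DLM this re-weighted map is what the concept-level causal supervision from the teacher model is aligned against, rather than the raw softmax scores.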

Method

(Figure: overview of the C²DLM method.)

File Structure

d1/
├── eval/                  # Evaluation scripts for downstream tasks (adapted from the official repository of "d1: Scaling Reasoning in Diffusion Large Language Models via Reinforcement Learning")
│   ├── math500.py
│   ├── stg.py
│   ├── calculate_gpqa_arc_acc.py
│   ├── parse_and_get_acc.py
│   └── ...
├── eval_results/          # Stores raw model generation outputs for evaluation
│   ├── gsm8k_512_BLOCK_LENGTH32_C2DLM/
│   ├── math_512_BLOCK_LENGTH32_C2DLM/
│   └── ...
├── draw/                  # Utility scripts for plotting and analysis
│   ├── bar_viz.py
│   ├── causal_llada_util.py
│   └── ...
├── train_COT_OrderPerturb/  # Implementation of our proposed COT-OrderPerturb task
│   └── experiment/
│       ├── cot_dataset_test.jsonl
│       └── multi_gpu_train.py
└── train_downstream/      # Scripts for training on general downstream tasks
    └── ...

Training

The training scripts leverage torchrun for multi-GPU execution.

Training on COT-OrderPerturb

To train a model on our proposed COT-OrderPerturb dataset:

cd train_COT_OrderPerturb/experiment
torchrun --nproc_per_node=[NUM_GPUS] multi_gpu_train.py [YOUR_ARGUMENTS]

Training on Downstream Datasets

To train a model on other downstream reasoning tasks:

cd train_downstream
torchrun --nproc_per_node=[NUM_GPUS] multi_gpu_train.py [YOUR_ARGUMENTS]

Evaluation

The evaluation is a two-step process: first, generate model outputs, and second, calculate the accuracy.

1. Generate Model Outputs

Use the eval.py script to produce raw outputs from your trained model. The results will be saved in the eval_results/ directory.

cd eval
python eval.py [YOUR_ARGUMENTS]

2. Calculate Accuracy

Once generations are complete, use the provided parsing scripts to calculate the final accuracy. The scripts read the output folders generated in the previous step.

Each folder in eval_results contains the model's raw output with the following fields: question, prompt_input, generations, and ground_truth.

cd eval

# For GSM8K and MATH500 tasks, use the general accuracy script
python parse_and_get_acc.py --file ../eval_results/[MY_OUTPUT_FOLDER]

# For GPQA, ARC, MMLU_STEM, and SAT, use the dedicated calculators
python calculate_gpqa_arc_mmlu_acc.py ../eval_results/[MY_GPQA_OUTPUT_FOLDER]
python calculate_sat_acc.py ../eval_results/[MY_SAT_OUTPUT_FOLDER]
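The raw output files described above can be inspected with a short script. The field names (question, prompt_input, generations, ground_truth) follow the README; the one-JSON-object-per-line file layout and the exact-match check are assumptions for illustration.

```python
import json

# Hypothetical record with the fields listed in the README
record = {
    "question": "What is 2 + 3?",
    "prompt_input": "Q: What is 2 + 3?\nA:",
    "generations": ["The answer is 5."],
    "ground_truth": "5",
}

def is_correct(rec):
    """Naive substring check: does the ground truth appear in the first generation?
    The real parsing scripts apply task-specific answer extraction."""
    return rec["ground_truth"] in rec["generations"][0]

line = json.dumps(record)     # one serialized record, as it might appear on disk
parsed = json.loads(line)
print(is_correct(parsed))     # True
```

The provided parse_and_get_acc.py and calculate_*_acc.py scripts implement the task-specific versions of this extraction and scoring.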

Visualization

The draw/ directory contains scripts used to generate the attention visualizations and performance plots from our paper.
