This repository implements the main experiments of our TACL 2024 paper, Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding.
The code is intended solely for reproducing the experiments. We thank the authors of RISK, on which our code is based.
We tested our code in the following environment.
- OS: Debian GNU/Linux 10 (buster)
- Python: 3.8.3
- CUDA: 11.2
- GPUs: NVIDIA V100 x 2
The experiment with DeBERTa-v3-large requires a different environment.
- OS: Debian GNU/Linux 10 (buster)
- Python: 3.8.3
- CUDA: 11.2
- GPUs: NVIDIA A100 (40GB) x 2
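To compare your machine with the tested environments listed above, a quick check along the following lines may help. print_env is a hypothetical helper, not part of this repository. It only prints the OS, the Python version, the visible GPUs (via nvidia-smi), and the CUDA toolkit release (via nvcc). It does not enforce any version.

```shell
# Hypothetical helper (not part of this repository): prints local environment
# details for comparison with the tested setups. It enforces no versions.
print_env() {
  echo "== OS =="
  if [ -r /etc/os-release ]; then
    # Read PRETTY_NAME in a subshell so the caller's variables are untouched
    ( . /etc/os-release && echo "${PRETTY_NAME:-unknown}" )
  else
    uname -sr
  fi

  echo "== Python =="
  if command -v python >/dev/null 2>&1; then
    python --version 2>&1
  elif command -v python3 >/dev/null 2>&1; then
    python3 --version 2>&1
  else
    echo "python not found"
  fi

  echo "== GPUs =="
  if command -v nvidia-smi >/dev/null 2>&1; then
    # One line per visible GPU: name and total memory
    nvidia-smi --query-gpu=name,memory.total --format=csv,noheader \
      || echo "nvidia-smi failed"
  else
    echo "nvidia-smi not found"
  fi

  echo "== CUDA toolkit =="
  if command -v nvcc >/dev/null 2>&1; then
    nvcc --version | grep -i release || true
  else
    echo "nvcc not found"
  fi
}
```

Source the snippet and run print_env.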
git clone https://github.com/CyberAgentAILab/posthoc-control-moe
cd posthoc-control-moe
Note: The exact versions of the libraries we used are specified in the requirements files for reproducibility. For improved security, consider updating the libraries, particularly PyTorch and Transformers. Note, however, that we have not tested reproducibility with the updated versions.
Install dependencies to reproduce the main results.
# For conda users
conda env create -f environment.yaml
conda activate posthoc-control-moe
# For the others
pip install --force-reinstall --no-cache-dir -r requirements.txt
For the experiment with DeBERTa-v3-large, use environment_deberta.yaml or requirements_deberta.txt instead.
# For conda users
conda env create -f environment_deberta.yaml
conda activate posthoc-control-moe-deberta
# For the others
pip install --force-reinstall --no-cache-dir -r requirements_deberta.txt
Download the datasets from here and place them as follows.
Alternatively, run gdown 'https://drive.google.com/drive/folders/1aleJytl3SAKdGBsxZbxznwusINOnTAzh?usp=share_link' --folder to download all the datasets at once.
The link was kindly provided by the authors of RISK.
./dataset/
├── multinli/
│   ├── train.tsv
│   └── dev_matched.tsv
├── hans/heuristics_evaluation_set.txt
├── qqp_paws/
│   ├── qqp_train.tsv
│   ├── qqp_dev.tsv
│   └── paws_devtest.tsv
└── fever/
    ├── fever.train.jsonl
    ├── fever.dev.jsonl
    ├── symmetric_v0.1/fever_symmetric_generated.jsonl
    └── symmetric_v0.2/fever_symmetric_test.jsonl
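After downloading, the layout above can be checked with a small shell function. check_datasets is a hypothetical helper, not part of the repository. It assumes exactly the file names shown in the tree and only tests that each file exists, not that its contents are correct.

```shell
# Hypothetical helper (not part of the repository): checks that every dataset
# file from the tree above exists under ROOT (default: ./dataset).
# Prints each missing file to stderr; returns 0 only if nothing is missing.
check_datasets() {
  root="${1:-./dataset}"
  missing=0
  for f in \
    multinli/train.tsv \
    multinli/dev_matched.tsv \
    hans/heuristics_evaluation_set.txt \
    qqp_paws/qqp_train.tsv \
    qqp_paws/qqp_dev.tsv \
    qqp_paws/paws_devtest.tsv \
    fever/fever.train.jsonl \
    fever/fever.dev.jsonl \
    fever/symmetric_v0.1/fever_symmetric_generated.jsonl \
    fever/symmetric_v0.2/fever_symmetric_test.jsonl
  do
    if [ ! -f "$root/$f" ]; then
      echo "missing: $root/$f" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

For example, check_datasets ./dataset && echo "all dataset files found". Missing files are reported on stderr, and the function returns a nonzero status.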
Original links for the datasets:
- MNLI: https://cims.nyu.edu/~sbowman/multinli/
- HANS: https://github.com/tommccoy1/hans
- QQP and PAWS: https://github.com/google-research-datasets/paws
- FEVER and FEVER-Symmetric: https://github.com/TalSchuster/FeverSymmetric
Train the mixture-of-experts model and save the checkpoint that performs best on the in-distribution (ID) dev set.
The commands below use seed 888, which yields performance close to the average reported in the paper.
The default seed is 777, and the analyses were conducted with that seed.
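The per-task training commands in this section and the evaluation commands later in this README share every flag except --dataset, --num_experts, --router_loss, --save_dir, and the model name; evaluation replaces --best_model_name NAME --save with --resume NAME --evaluate. The sketch below makes that pattern explicit for the BERT-base runs. mos_cmd is a hypothetical helper, not part of the repository. It only prints a command assembled from the values listed in this README and does not cover the DeBERTa-v3-large runs.

```shell
# Hypothetical helper (not part of the repository): prints the BERT-base
# training or evaluation command for one task, using the settings in this README.
# Usage: mos_cmd TASK MODE   (TASK: mnli|qqp|fever, MODE: train|evaluate)
mos_cmd() {
  task="$1"
  mode="$2"
  # Per-task settings copied from the commands in this README
  case "$task" in
    mnli)  experts=10; router_loss=0.5; name=bert_mos_e10_rs05k8_ep10_lr2e-5_8 ;;
    qqp)   experts=15; router_loss=1;   name=bert_mos_e15_rs1k8_ep10_lr2e-5_8 ;;
    fever) experts=10; router_loss=1;   name=bert_mos_e10_rs1k8_ep10_lr2e-5_8 ;;
    *) echo "unknown task: $task" >&2; return 1 ;;
  esac
  # Training saves the best checkpoint; evaluation resumes from it
  case "$mode" in
    train)    suffix="--best_model_name $name --save" ;;
    evaluate) suffix="--resume $name --evaluate" ;;
    *) echo "unknown mode: $mode" >&2; return 1 ;;
  esac
  echo "CUDA_VISIBLE_DEVICES=0,1 accelerate launch" \
    "--config_file accelerate_config.yaml --main_process_port 20880" \
    "src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased" \
    "--dataset $task --batch_size 32 --epochs 10" \
    "--num_experts $experts --router_loss $router_loss --router_tau 1" \
    "--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/$task" \
    "$suffix"
}
```

For example, eval "$(mos_cmd mnli train)" runs the MNLI training command, and eval "$(mos_cmd mnli evaluate)" runs the corresponding HANS evaluation. Printing the command first makes it easy to inspect before running it.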
mkdir -p saved_models/mnli
mkdir -p saved_models/qqp
mkdir -p saved_models/fever
# MNLI
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset mnli --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 0.5 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/mnli \
--best_model_name bert_mos_e10_rs05k8_ep10_lr2e-5_8 --save
# QQP
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset qqp --batch_size 32 --epochs 10 \
--num_experts 15 --router_loss 1 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/qqp \
--best_model_name bert_mos_e15_rs1k8_ep10_lr2e-5_8 --save
# FEVER
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset fever --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 1 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/fever \
--best_model_name bert_mos_e10_rs1k8_ep10_lr2e-5_8 --save
For the DeBERTa-v3-large ablation study:
# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config_deberta.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path microsoft/deberta-v3-large \
--dataset mnli --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 0.5 --router_tau 1 \
--num_topk_mask 8 --lr 5e-6 --max_grad_norm 1 --seed 888 \
--save_dir saved_models/mnli \
--best_model_name deberta_mos_e10_rs05k8_ep10_lr5e-6g1_bf16_8 --save
Evaluate the post-hoc control over the mixture-of-experts on the out-of-distribution (OOD) test sets.
Some saved models are available here for those who want to check the results quickly.
Download and place them under saved_models/[task_name]/.
# HANS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset mnli --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 0.5 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/mnli \
--resume bert_mos_e10_rs05k8_ep10_lr2e-5_8 --evaluate
# PAWS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset qqp --batch_size 32 --epochs 10 \
--num_experts 15 --router_loss 1 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/qqp \
--resume bert_mos_e15_rs1k8_ep10_lr2e-5_8 --evaluate
# Symm. v1 and v2
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset fever --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 1 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/fever \
--resume bert_mos_e10_rs1k8_ep10_lr2e-5_8 --evaluate
For the DeBERTa-v3-large ablation study:
# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config_deberta.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path microsoft/deberta-v3-large \
--dataset mnli --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 0.5 --router_tau 1 \
--num_topk_mask 8 --lr 5e-6 --max_grad_norm 1 --seed 888 \
--save_dir saved_models/mnli \
--resume deberta_mos_e10_rs05k8_ep10_lr5e-6g1_bf16_8 --evaluate
If you find our work useful for your research, please consider citing our paper:
@article{10.1162/tacl_a_00701,
author = {Honda, Ukyo and Oka, Tatsushi and Zhang, Peinan and Mita, Masato},
title = {Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding},
journal = {Transactions of the Association for Computational Linguistics},
volume = {12},
pages = {1268-1289},
year = {2024},
month = {10},
issn = {2307-387X},
doi = {10.1162/tacl_a_00701},
url = {https://doi.org/10.1162/tacl\_a\_00701},
eprint = {https://direct.mit.edu/tacl/article-pdf/doi/10.1162/tacl\_a\_00701/2480600/tacl\_a\_00701.pdf},
}