This repo contains the code for the Fill-in Language Model (FiLM) described in the paper *FiLM: Fill-in Language Models for Any-Order Generation* (Shen et al., 2023).
FiLM is an any-order language model that can fill in text in the middle. For example:

> Depression, loneliness and stress increase the risk of, say, drug abuse.

> I tried going to the park the other day. The weather seemed nice enough for a walk. However, when I got there I started to itch. My eyes were watery and it was hard to breathe. My allergies were too bad and I had to go back home.
Its training extends the masked language modeling objective by adopting varying mask probabilities sampled from a Beta distribution, which enhances its generative capabilities. At decoding time, FiLM can start either from a sequence consisting entirely of masks or from partially complete text interspersed with masks, and it progressively replaces one mask with a predicted token at each step.
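The masking idea is easy to sketch. Below is a minimal, illustrative version (not the repo's actual code; `film_mask` and `mask_token_id` are hypothetical names) of masking a training batch with a per-sequence rate drawn from Beta(2.5, 2.5), the parameters used by the training command further down:

```python
import torch

def film_mask(input_ids: torch.Tensor, mask_token_id: int, a: float = 2.5, b: float = 2.5):
    """Mask each sequence with a rate sampled from Beta(a, b) (illustrative sketch)."""
    batch, seq_len = input_ids.shape
    # One mask probability per sequence, sampled from the Beta distribution,
    # so the model sees everything from near-empty to near-complete contexts.
    p = torch.distributions.Beta(a, b).sample((batch, 1))
    # Mask each token independently with probability p.
    is_masked = torch.rand(batch, seq_len) < p
    masked_ids = input_ids.masked_fill(is_masked, mask_token_id)
    # Labels: predict the original token at masked positions only
    # (-100 is ignored by PyTorch's cross-entropy loss).
    labels = input_ids.masked_fill(~is_masked, -100)
    return masked_ids, labels
```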
Install the dependencies:
```
pip install -r requirements.txt
```
Then download and preprocess the data:
```
python -m data_process.get_data --dataset $dataset
bash tokenize.sh $dataset $tokenizer
bash mask.sh $dataset $tokenizer
```
where `$dataset` can be `wikitext-103`, `lm1b`, or `roc_stories`, and `$tokenizer` can be `roberta` or `gpt2`.
To train FiLM, run:
```
python -m model.train \
    --pretrained_model $pretrained_model \
    --train data/$dataset/$tokenizer/train.id --valid data/$dataset/$tokenizer/valid.id \
    --save_dir checkpoints/$dataset/film/$pretrained_model \
    --max_tokens 20000 --accum_grad 1 --gpus 1 --precision bf16-mixed \
    --lr 2e-5 \
    --weight_func beta --weight_param 2.5 2.5 \
    --train_steps 500000
```
Choose your `$pretrained_model` and the corresponding `$tokenizer` from the options below:
- for `roberta-base` and `roberta-large`, set `$tokenizer` to `roberta`;
- for `gpt2`, `gpt2-medium`, `gpt2-large`, and `gpt2-xl`, set `$tokenizer` to `gpt2`.
To evaluate FiLM, run:
```
python -m model.eval \
    --input data/$dataset/$tokenizer/mask/test.mask.all.id --target data/$dataset/$tokenizer/test.id \
    --checkpoint checkpoints/$dataset/film/$pretrained_model/lightning_logs/version_0 \
    --order $order
```
Then compute the overall perplexity:
```
python -m eval.ppl_overall \
    --len_count data/$dataset/$tokenizer/train.len_count \
    --loss_masks checkpoints/$dataset/film/$pretrained_model/lightning_logs/version_0/eval/test.infill.all.$order.loss_masks
```
where `$order` can be `random`, `left2right`, `right2left`, `min-entropy`, or `max-entropy`.
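The entropy-based orders pick which mask to fill next based on how confident the model is at each masked position. Here is a minimal illustrative sketch of that selection step (the repo's decoding code is its own implementation; `next_position` is a hypothetical helper):

```python
import torch

def next_position(logits: torch.Tensor, masked_positions: list[int], order: str) -> int:
    """logits: (seq_len, vocab) predictions; returns the index to fill next."""
    if order == "left2right":
        return min(masked_positions)
    if order == "right2left":
        return max(masked_positions)
    if order == "random":
        return masked_positions[torch.randint(len(masked_positions), (1,)).item()]
    # Entropy of the predicted distribution at each still-masked position:
    # min-entropy fills the most confident position first, max-entropy the least.
    probs = torch.softmax(logits[masked_positions], dim=-1)
    entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=-1)
    idx = entropy.argmin() if order == "min-entropy" else entropy.argmax()
    return masked_positions[idx.item()]
```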
To generate text infills, run:
```
python -m model.generate \
    --input data/$dataset/$tokenizer/mask/test.mask.span1.id \
    --checkpoint checkpoints/$dataset/film/$pretrained_model/lightning_logs/version_0 \
    --order $order --temp 0.8 --top_p 0.95
```
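The `--temp` and `--top_p` flags apply temperature scaling and nucleus (top-p) sampling when choosing each filled-in token. A minimal sketch of that standard technique (illustrative, not the repo's sampler):

```python
import torch

def sample_token(logits: torch.Tensor, temp: float = 0.8, top_p: float = 0.95) -> int:
    """Sample one token id from (vocab,)-shaped logits with temperature + top-p."""
    probs = torch.softmax(logits / temp, dim=-1)
    sorted_probs, sorted_ids = probs.sort(descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Keep the smallest set of tokens whose cumulative probability reaches top_p.
    cutoff = int((cumulative < top_p).sum().item()) + 1
    kept = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
    choice = torch.multinomial(kept, 1).item()
    return sorted_ids[choice].item()
```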
Then extract the filled-in spans from the model output:
```
python -m data_process.extract \
    --mask data/$dataset/$tokenizer/mask/test.mask.span1.id \
    --infill checkpoints/$dataset/film/$pretrained_model/lightning_logs/version_0/generate/test.infill.span1.id
```
Convert the token ids back to text:
```
python -m data_process.id2text \
    --tokenizer roberta \
    --file checkpoints/$dataset/film/$pretrained_model/lightning_logs/version_0/generate/test.fill.span1.id
```
Finally, compute ROUGE scores of the generated fills against the references:
```
python -m eval.rouge \
    --ref data/$dataset/$tokenizer/mask/test.fill.span1.id \
    --gen checkpoints/$dataset/film/$pretrained_model/lightning_logs/version_0/generate/test.fill.span1.id
```
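For a sense of what a single-pair ROUGE computation looks like, here is a minimal sketch using Google's `rouge-score` package (an assumption for illustration; `eval.rouge` above runs its own scoring over the whole file):

```python
# pip install rouge-score
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
ref = "My allergies were too bad and I had to go back home."
gen = "My allergies got so bad that I went back home."
# Each entry is a Score(precision=..., recall=..., fmeasure=...) tuple.
print(scorer.score(ref, gen))
```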