Repo for natural language proof generation with large language models using contrastive stepwise decoding.
We use Lightning CLI; set up the environment with conda env create -f nlproofs.yaml. Refer to NLProofs for details.
Scripts for Vanilla Prompting, CoT, and Select-and-Inference are placed under ./scripts:
python prompting.py
python cot.py
python SI.py
ConDec is our framework for contrastive decoding with hard negatives. After fine-tuning with the MLE loss, the generator is further tuned on hard negatives. To fine-tune with the MLE loss:
cd ./stepwise
CUDA_VISIBLE_DEVICES=0 python main.py fit --config cli_task1_stepwise_flan-t5-large.yaml
CUDA_VISIBLE_DEVICES=0 python main.py fit --config cli_task2_stepwise_flan-t5-large.yaml
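The contrastive adjustment can be pictured as a margin loss that pushes the generator to score the gold proof step above each hard negative. The sketch below is a hypothetical illustration of this idea in plain Python, not the exact ConDec objective:

```python
# Hypothetical sketch of a margin-based contrastive objective over sequence
# log-probabilities; the actual loss used in this repo may differ.

def contrastive_loss(logp_gold, logp_negatives, margin=1.0):
    """Penalize hard negatives whose score approaches the gold proof step.

    logp_gold: log-probability the generator assigns to the gold step.
    logp_negatives: log-probabilities assigned to the hard-negative steps.
    """
    hinges = [max(0.0, margin - (logp_gold - logp_neg))
              for logp_neg in logp_negatives]
    return sum(hinges) / len(hinges)

# The gold step outscores both negatives by more than the margin,
# so the hinge is inactive:
print(contrastive_loss(-2.0, [-5.0, -6.0]))  # 0.0
```

Hard negatives that score close to (or above) the gold step produce a positive loss, which is what drives the generator away from them.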
The vanilla hard negatives are constructed by randomly substituting intermediate nodes with premises. To fine-tune with vanilla hard negatives:
cd ./ConDec
CUDA_VISIBLE_DEVICES=0 python main.py fit --config cli_task1_vanilla_flan-t5-large.yaml \
--ckpt_path ../stepwise/ckpt_entailmentbank_task1/lightning_logs/version0/epoch\=499-step\=10500.ckpt
CUDA_VISIBLE_DEVICES=0 python main.py fit --config cli_task2_vanilla_flan-t5-large.yaml \
--ckpt_path ../stepwise/ckpt_entailmentbank_task1/lightning_logs/version0/epoch\=599-step\=12600.ckpt
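As a hypothetical illustration, the random substitution described above might look like the sketch below. The identifier conventions (sentN for premises, intN for intermediate nodes) follow EntailmentBank, but the actual construction code may differ:

```python
import random

# Hypothetical sketch of vanilla hard-negative construction: corrupt a gold
# proof step by swapping one intermediate-node id (e.g. "int1") for a
# randomly chosen premise id. Not the actual implementation.

def make_vanilla_negative(step, premises, rng=random):
    """Replace one intermediate-node id in a proof step with a premise id."""
    tokens = step.split()
    int_positions = [i for i, t in enumerate(tokens) if t.startswith("int")]
    if not int_positions:
        return step  # nothing to corrupt
    pos = rng.choice(int_positions)
    tokens[pos] = rng.choice(premises)
    return " ".join(tokens)

step = "sent1 & int1 -> int2"
print(make_vanilla_negative(step, ["sent2", "sent3"], random.Random(0)))
```

The corrupted step keeps the gold surface form except for one node, which is what makes it a hard (rather than trivially wrong) negative.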
The construction of enhanced hard negatives consists of three stages: training the reasoner, running inference with the reasoner, and filtering with the checker.
Preprocess and sample the training data from the training set:
cd ./reasoner/data_sample
python datasample.py
Since the gold proof trees for task1 and task2 are the same, the acquired training data is identical whether sampled from task1 or task2. Train the reasoner on this data:
cd ./reasoner/train
CUDA_VISIBLE_DEVICES=0 python main.py fit --config cli_entailmentbank_task1.yaml
Run inference with the trained reasoner:
cd ./reasoner/inference
CUDA_VISIBLE_DEVICES=0 python main.py validate --config cli_entailmentbank_task1.yaml
CUDA_VISIBLE_DEVICES=0 python main.py validate --config cli_entailmentbank_task2.yaml
This produces hard negatives for task1 and task2 separately. The sampling strategy can be random or BM25; the default is random.
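To illustrate the BM25 option, the sketch below scores candidate premises against a node's text with a minimal Okapi BM25 implementation and picks the top-scoring one (versus the default uniform-random choice). The parameter values (k1, b) are the common defaults and the scoring interface is an assumption, not the code used in this repo:

```python
import math
from collections import Counter

# Hypothetical sketch of BM25-based premise sampling: rank premises by
# lexical similarity to a query (e.g. an intermediate node's text).

def bm25_scores(query, docs, k1=1.5, b=0.75):
    """Return one Okapi BM25 score per document for the given query."""
    docs_tok = [d.lower().split() for d in docs]
    n = len(docs_tok)
    avgdl = sum(len(d) for d in docs_tok) / n
    df = Counter()                      # document frequency per term
    for d in docs_tok:
        df.update(set(d))
    scores = []
    for d in docs_tok:
        tf = Counter(d)                 # term frequency in this document
        s = 0.0
        for term in query.lower().split():
            if tf[term] == 0:
                continue
            idf = math.log((n - df[term] + 0.5) / (df[term] + 0.5) + 1.0)
            s += idf * tf[term] * (k1 + 1) / (
                tf[term] + k1 * (1 - b + b * len(d) / avgdl))
        scores.append(s)
    return scores

premises = ["an eclipse blocks sunlight",
            "plants need water",
            "the moon orbits the earth"]
scores = bm25_scores("the moon blocks the sun", premises)
print(premises[max(range(len(scores)), key=scores.__getitem__)])
# the moon orbits the earth
```

BM25 picks a lexically similar premise, which tends to yield harder negatives than a uniform-random pick.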
cd ./reasoner/filter
python verify.py --task task1
Run the same command with --task task2.
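The filtering stage can be pictured as keeping only candidates that the checker confidently judges invalid, so every retained hard negative is genuinely wrong. The checker interface and threshold below are assumptions for illustration, not the actual verify.py logic:

```python
# Hypothetical sketch of checker-based filtering: the checker is assumed to
# return the probability that a candidate step is a valid entailment, and we
# keep a candidate as a hard negative only if that probability is low.

def filter_negatives(candidates, checker_score, max_valid_prob=0.1):
    """Drop candidate negatives that the checker considers plausibly valid."""
    return [c for c in candidates if checker_score(c) < max_valid_prob]

# Toy stand-in checker: pretend steps mentioning "int1" look valid to it.
toy_checker = lambda step: 0.9 if "int1" in step else 0.05
negs = ["sent1 & sent2 -> int1", "sent1 & sent3 -> int2"]
print(filter_negatives(negs, toy_checker))  # ['sent1 & sent3 -> int2']
```

Raising max_valid_prob keeps more (but noisier) negatives; lowering it keeps only the candidates the checker is most sure are wrong.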
cd ./ConDec
CUDA_VISIBLE_DEVICES=0 python main.py fit --config cli_task1_enhanced_flan-t5-large.yaml \
--ckpt_path ../stepwise/ckpt_entailmentbank_task1/lightning_logs/version0/epoch\=499-step\=10500.ckpt
CUDA_VISIBLE_DEVICES=0 python main.py fit --config cli_task2_enhanced_flan-t5-large.yaml \
--ckpt_path ../stepwise/ckpt_entailmentbank_task1/lightning_logs/version0/epoch\=599-step\=12600.ckpt
For evaluation, refer to the official EntailmentBank toolkit.