This repository contains the code for the paper: *Improving Factuality of Abstractive Summarization without Sacrificing Summary Quality* (ACL 2023).
Improving factual consistency of abstractive summarization has been a widely studied topic. However, most prior work on training factuality-aware models has ignored the negative effect this has on summary quality. We propose EFactSum (i.e., Effective Factual Summarization), a candidate summary generation and ranking technique that improves summary factuality without sacrificing summary quality. We show that using a contrastive learning framework (Liu et al. 2022) with our refined candidate summaries leads to significant gains on both factuality and similarity-based metrics. Specifically, we propose a ranking strategy that effectively combines two metrics, thereby preventing any conflict during training. Models trained with our approach show up to 6 points of absolute improvement over the base model with respect to FactCC on XSUM and 11 points on CNN/DM, without negatively affecting either similarity-based metrics or abstractiveness.
For each training article, we sample 16 summaries using the base model (PEGASUS/BART) and 16 summaries using CLIFF. Each summary is scored along two axes, ROUGE and FactCC, and out of these 32 summaries, 6 are chosen based on our ranking strategy (see Section 2 of the paper). You can download the data from these links: XSUM, CNN/DM
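To make the candidate-selection step concrete, here is a minimal sketch of ranking a candidate pool along two score axes and keeping the top-k. The exact combination strategy used by EFactSum is the one described in Section 2 of the paper; the lexicographic FactCC-then-ROUGE tie-break below is only an illustrative assumption, and `select_candidates` is a hypothetical helper, not part of this repo.

```python
def select_candidates(candidates, k=6):
    """Pick the top-k candidates from a scored pool.

    candidates: list of dicts with 'text', 'factcc', and 'rouge' keys.
    Sorts by FactCC first and breaks ties with ROUGE (an illustrative
    choice, not necessarily the paper's exact combination).
    """
    ranked = sorted(candidates, key=lambda c: (c["factcc"], c["rouge"]), reverse=True)
    return ranked[:k]

# Toy pool of 4 candidates instead of the 32 sampled in the paper.
pool = [
    {"text": "a", "factcc": 0.9, "rouge": 0.30},
    {"text": "b", "factcc": 0.9, "rouge": 0.45},
    {"text": "c", "factcc": 0.2, "rouge": 0.60},
    {"text": "d", "factcc": 0.7, "rouge": 0.50},
]
top = select_candidates(pool, k=2)
print([c["text"] for c in top])  # ['b', 'a']
```

Candidate "c" has the best ROUGE but the worst FactCC, so a factuality-first ranking drops it; this is the kind of conflict between the two metrics that the paper's ranking strategy is designed to resolve.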
The script for generating the summaries using our trained models is in the decoding folder.
For XSUM, run:

```
python3 run_sum.py \
    --model_path tanay/efactsum-pegasus-xsum \
    --source ../outputs/xsum.test.source.txt \
    --output_dir $OUTPUT_DIR \
    --batch_size 2 \
    --max_length 512 \
    --gen_max_len 62 \
    --gen_min_len 11 \
    --num_beams 8 \
    --length_penalty 0.6
```
For CNNDM, run:

```
python3 run_sum.py \
    --model_path tanay/efactsum-bart-cnndm \
    --source ../outputs/cnndm.test.source.txt \
    --output_dir $OUTPUT_DIR \
    --batch_size 2 \
    --max_length 1024 \
    --gen_max_len 140 \
    --gen_min_len 55 \
    --num_beams 4 \
    --length_penalty 2
```
We summarize the outputs from our models below:

| Dataset | Model | Source | Model Output | Reference Output |
|---|---|---|---|---|
| CNNDM | tanay/efactsum-bart-cnndm | cnndm.test.source | cnndm.test.ours | cnndm.test.target |
| XSum | tanay/efactsum-pegasus-xsum | xsum.test.source | xsum.test.ours | xsum.test.target |
You can load our trained models from HuggingFace Transformers. Our model checkpoint on CNNDM (`tanay/efactsum-bart-cnndm`) is a standard BART model (i.e., `BartForConditionalGeneration`), while our model checkpoint on XSum (`tanay/efactsum-pegasus-xsum`) is a standard PEGASUS model (i.e., `PegasusForConditionalGeneration`).
Example usage with HuggingFace:

```python
from transformers import BartTokenizer, PegasusTokenizer
from transformers import BartForConditionalGeneration, PegasusForConditionalGeneration

IS_CNNDM = True  # switch between the CNNDM and XSum checkpoints
max_length = 1024 if IS_CNNDM else 512

ARTICLE_TO_SUMMARIZE = "firefighters responded to cries for help - from two parrots. the crew scoured a burning home in boise, idaho, searching \
for people shouting 'help!' and 'fire!' eventually, to their surprise, they found a pair of squawking birds. \
scroll down for video. cry for help! this is one of the two parrots who were found in a burning home after calling for help. \
the tropical creatures appeared to have been alone when flames began to sweep the property. but they seemed to know what to do. \
treatment: the officials treated the birds with oxygen masks and both are expected to survive. according to kboi, the officers \
managed to contain the fire to just one room. the cause is being investigated and no people were found inside. officials have yet to track down the birds' owners."

if IS_CNNDM:
    model = BartForConditionalGeneration.from_pretrained('tanay/efactsum-bart-cnndm')
    tokenizer = BartTokenizer.from_pretrained('tanay/efactsum-bart-cnndm')
else:
    model = PegasusForConditionalGeneration.from_pretrained('tanay/efactsum-pegasus-xsum')
    tokenizer = PegasusTokenizer.from_pretrained('tanay/efactsum-pegasus-xsum')

article = ARTICLE_TO_SUMMARIZE.lower()
inputs = tokenizer([article], max_length=max_length, return_tensors="pt", truncation=True)

# Generate summary
summary_ids = model.generate(inputs["input_ids"])
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```
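The example above relies on `generate()` defaults. If you want decoding to match the settings used in the decoding section, the CLI flags map roughly onto `generate()` keyword arguments as sketched below (an assumption about how `run_sum.py` wires them up; the script itself is the reference):

```python
IS_CNNDM = True  # same switch as in the example above

# Map the run_sum.py decoding flags onto HuggingFace generate() kwargs.
gen_kwargs = dict(
    max_length=140 if IS_CNNDM else 62,        # --gen_max_len
    min_length=55 if IS_CNNDM else 11,         # --gen_min_len
    num_beams=4 if IS_CNNDM else 8,            # --num_beams
    length_penalty=2.0 if IS_CNNDM else 0.6,   # --length_penalty
)

# Then replace the greedy call with beam search:
# summary_ids = model.generate(inputs["input_ids"], **gen_kwargs)
print(gen_kwargs)
```

Note that `--max_length` in the CLI controls input truncation (the `max_length` passed to the tokenizer), while `--gen_max_len` bounds the generated summary.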