CorpusBrain

This is the source code for the paper "CorpusBrain: Pre-train a Generative Retrieval Model for Knowledge-Intensive Language Tasks".

Introduction

Overview

CorpusBrain is a pre-trained generative retrieval model that encodes all information about the corpus in its parameters, without the need to construct an additional index. By replacing the traditional multi-step search pipeline with a single-step generative model, CorpusBrain dramatically simplifies the search process and can be optimized end to end. We show that a strong generative retrieval model can be learned with a set of adequately designed pre-training tasks, and that it can be further fine-tuned to improve a variety of downstream retrieval tasks.

Requirements and Installation

  • Python version >= 3.7
  • PyTorch version >= 1.6.0
  • fairseq version >= 0.10

Pre-training Corpus

Download Data

The KILT knowledge source can be downloaded here: kilt_knowledgesource.json (34.76GiB). It is based on the 2019/08/01 Wikipedia dump.
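Because the knowledge source is far too large to load into memory at once, it helps to read it as a stream. The sketch below assumes the file stores one JSON object per line, with `wikipedia_title` and `text` fields, as described in the KILT repository; the helper name `iter_pages` and the sample records are hypothetical and used only for illustration.

```python
import io
import json
from typing import Iterator, TextIO


def iter_pages(lines: TextIO) -> Iterator[dict]:
    """Yield one parsed page per non-empty line of a JSON Lines stream."""
    for line in lines:
        line = line.strip()
        if line:
            yield json.loads(line)


# Toy records standing in for the real file (field names are an assumption).
records = [
    {"wikipedia_id": "1", "wikipedia_title": "Albert Einstein",
     "text": ["Albert Einstein\n", "He was a physicist.\n"]},
    {"wikipedia_id": "2", "wikipedia_title": "Example",
     "text": ["Example\n"]},
]
sample = io.StringIO("\n".join(json.dumps(r) for r in records) + "\n")

# Only one page is held in memory at a time while iterating.
titles = [page["wikipedia_title"] for page in iter_pages(sample)]
print(titles)
```

With the real file, the same generator could be fed from `open("kilt_knowledgesource.json", encoding="utf-8")` in place of the in-memory sample.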

The BPE prefix tree (trie) built from KILT Wikipedia titles, also based on the 2019/08/01 Wikipedia dump, can be downloaded here: kilt_titles_trie_dict.pkl. The trie contains ~5M titles and is used to generate document identifiers for all the KILT experiments.
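To illustrate how a prefix tree can restrict generation to valid titles, here is a minimal sketch of a trie over token-id sequences. It is not the `genre.trie.Trie` implementation; the class name `ToyTrie` and the token ids are hypothetical. The only behavior it mirrors is a `get(prefix)` lookup that returns the tokens allowed to follow a given prefix.

```python
from typing import Dict, Iterable, List


class ToyTrie:
    """Nested-dict prefix tree over sequences of token ids."""

    def __init__(self, sequences: Iterable[List[int]] = ()):
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            self.add(seq)

    def add(self, seq: List[int]) -> None:
        # Walk down the tree, creating child nodes as needed.
        node = self.root
        for tok in seq:
            node = node.setdefault(tok, {})

    def get(self, prefix: List[int]) -> List[int]:
        # Return the token ids that may follow `prefix`, or [] if the
        # prefix does not appear in any stored sequence.
        node = self.root
        for tok in prefix:
            if tok not in node:
                return []
            node = node[tok]
        return sorted(node.keys())


# Hypothetical token-id sequences for three titles (2 = start, 1 = end).
trie = ToyTrie([[2, 10, 11, 1], [2, 10, 12, 1], [2, 20, 1]])
print(trie.get([2]))        # tokens allowed after the start token
print(trie.get([2, 10]))    # tokens allowed after the shared prefix
print(trie.get([2, 99]))    # unknown prefix: nothing is allowed
```

During constrained decoding, the decoder would be allowed to pick only from the tokens returned for the prefix generated so far, so every finished sequence corresponds to a stored title.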

The BART-Large checkpoint can be downloaded here: bart.large.tar.gz

Pre-process Data

To construct the pre-training data, run the following command:

bash scripts/preprocess_corpus.sh

Pre-training

Tokenize and Binarize Data

To tokenize and binarize the data in the format fairseq expects, run:

bash scripts/preprocess_fairseq.sh $DATASET_PATH $MODEL_PATH

Pre-train Model

To pre-train the model, run the following command:

bash scripts/train.sh $DATASET_PATH $NAME

Fine-tuning

CorpusBrain has the same architecture as BART-Large, so you can fine-tune it like BART on any downstream retrieval task by simply replacing the BART checkpoint with the CorpusBrain checkpoint.

Download KILT Data

CorpusBrain can be fine-tuned to serve a variety of downstream retrieval tasks in KILT. To download KILT data, run the following commands:

mkdir data
python scripts/download_all_kilt_data.py
python scripts/get_triviaqa_input.py

Convert KILT Data to Fairseq Format

To convert KILT data to fairseq format, run the following command:

python scripts/convert_kilt_to_fairseq_genre.py $input_filename $output_path
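As a rough picture of what such a conversion involves, the sketch below turns one KILT record into (source, target) pairs, where the source is the query text and the target is a provenance page title. It assumes the KILT record layout (`input`, `output`, `provenance`, `title`); the function name `kilt_to_pairs` and the de-duplication of titles are illustrative choices, and the actual script may behave differently.

```python
from typing import Dict, List, Tuple


def kilt_to_pairs(record: Dict) -> List[Tuple[str, str]]:
    """Return (query, title) pairs for every distinct provenance title."""
    pairs: List[Tuple[str, str]] = []
    seen = set()
    for output in record.get("output", []):
        for prov in output.get("provenance", []):
            title = prov.get("title")
            if title and title not in seen:
                seen.add(title)
                pairs.append((record["input"], title))
    return pairs


# Toy record following the assumed KILT layout.
record = {
    "id": "q1",
    "input": "Einstein was a German physicist.",
    "output": [
        {"answer": "yes",
         "provenance": [{"wikipedia_id": "1", "title": "Albert Einstein"}]},
        {"provenance": [{"wikipedia_id": "1", "title": "Albert Einstein"},
                        {"wikipedia_id": "2", "title": "Physics"}]},
    ],
}

pairs = kilt_to_pairs(record)
for source, target in pairs:
    print(source, "->", target)
```

In a seq2seq setup, each source would typically be written to one line of a source file and the matching target to the same line of a target file.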

Tokenize, Binarize and Fine-tune

These steps are the same as the pre-training steps above.

Inference

After importing and loading the model and a prefix tree (trie), you can generate predictions with a simple call:

import pickle

from genre.fairseq_model import GENRE
from genre.trie import Trie

# Load the prefix tree (trie)
with open("../data/kilt_titles_trie_dict.pkl", "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

# Load the model
model = GENRE.from_pretrained("models/corpus_brain").eval()

# Generate Wikipedia titles
model.sample(
    sentences=["Einstein was a German physicist."],
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)

Example output:

[[{'text': 'Albert Einstein', 'score': tensor(-0.0708)},
  {'text': 'Werner Bruschke', 'score': tensor(-1.5357)},
  {'text': 'Werner von Habsburg', 'score': tensor(-1.8696)},
  {'text': 'Werner von Moltke', 'score': tensor(-2.2318)},
  {'text': 'Werner von Eichstedt', 'score': tensor(-3.0177)}]]
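The returned structure is a list with one entry per input sentence, each holding candidate dictionaries with `text` and `score` keys. A small helper like the following could extract the top-ranked titles. This is a sketch: `top_titles` is a hypothetical name, and plain floats stand in for the tensor scores shown in the output.

```python
from typing import Dict, List


def top_titles(predictions: List[List[Dict]], k: int = 1) -> List[List[str]]:
    """Return the k highest-scoring titles for each input sentence."""
    ranked = []
    for candidates in predictions:
        # float() accepts plain numbers as well as single-element tensors.
        ordered = sorted(candidates, key=lambda c: float(c["score"]), reverse=True)
        ranked.append([c["text"] for c in ordered[:k]])
    return ranked


# Scores copied from the example output, with floats replacing tensors.
predictions = [[
    {"text": "Albert Einstein", "score": -0.0708},
    {"text": "Werner Bruschke", "score": -1.5357},
    {"text": "Werner von Habsburg", "score": -1.8696},
]]

print(top_titles(predictions, k=2))
```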

Acknowledgement

License

This project is licensed under the Apache License 2.0.

Citation

If you find our work useful, please consider citing our paper:

@article{chen2022corpusbrain,
  title={CorpusBrain: Pre-train a Generative Retrieval Model for Knowledge-Intensive Language Tasks},
  author={Chen, Jiangui and Zhang, Ruqing and Guo, Jiafeng and Liu, Yiqun and Fan, Yixing and Cheng, Xueqi},
  journal={arXiv preprint arXiv:2208.07652},
  year={2022}
}