This project is a fork of salesforce/ctrl-sum


Resources for the "CTRLsum: Towards Generic Controllable Text Summarization" paper

Home Page: https://arxiv.org/abs/2012.04281

License: BSD 3-Clause "New" or "Revised" License


CTRLsum

This is the PyTorch implementation of the paper:

CTRLsum: Towards Generic Controllable Text Summarization
Junxian He, Wojciech Kryściński, Bryan McCann, Nazneen Rajani, Caiming Xiong
arXiv 2020

This repo includes instructions for using pretrained CTRLsum models as well as training new models.

CTRLsum is a generic controllable summarization system that manipulates text summaries given control tokens in the form of keywords or prefixes. CTRLsum also achieves strong summarization performance (e.g. state-of-the-art on CNN/Dailymail) in an uncontrolled setting.

Model checkpoints

Dataset | Download
CNN/DailyMail | download (.tar.gz)
arXiv | download (.tar.gz)
BIGPATENT | download (.tar.gz)

Dependencies

The code requires Python 3, PyTorch (>=1.4.0), and fairseq (the code is tested on the specific fairseq commit listed below)

Install dependencies:

# manually install fairseq
git clone https://github.com/pytorch/fairseq

# this repo is tested on a commit of fairseq from May 2020:
# fad3cf0769843e767155f4d0af18a61b9a804f59
cd fairseq
git reset --hard fad3cf07

# the BART interface in fairseq does not support prefix-constrained decoding
# as of the writing of this README, so we need to make several modifications to
# fairseq before installing it
cp ../ctrlsum/fairseq_task.py fairseq/tasks/fairseq_task.py
cp ../ctrlsum/sequence_generator.py fairseq/
cp ../ctrlsum/hub_interface.py fairseq/models/bart/

# install fairseq
pip install --editable ./

cd ..

# install other requirements
pip install -r requirements.txt
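
After installation, a quick sanity check to confirm the editable fairseq install is importable (this is a generic check, not part of the project's scripts):

# confirm the patched, editable fairseq install is on the Python path
import fairseq
print(fairseq.__version__)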

Example Usage of Pretrained Models

Generate summaries interactively; users can specify the control tokens (keywords, prompts, or both):

CUDA_VISIBLE_DEVICES=xx python scripts/generate_bart_interactive.py --exp [checkpoint directory] \
	--dataset example_dataset \
	--src test.oraclewordnssource

The command above reads source articles from datasets/example_dataset/test.oraclewordnssource. Users can then interact with the system on the command line by entering the id of the example to be shown, along with the control tokens.


Generate summaries from a file which includes keywords:

# the following command generates summaries from `datasets/example_dataset/test.oraclewordnssource`
# the input data format is the keywords and the source concatenated with a separator token;
# please refer to the given example data files for concrete examples (a sketch of this format follows the command below)
# the predicted summaries are saved into the checkpoint directory
CUDA_VISIBLE_DEVICES=xx python scripts/generate_bart.py --exp [checkpoint directory] \
	--dataset example_dataset \
	--src test.oraclewordnssource 
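
For reference, here is a minimal sketch of how such an input line could be assembled. The separator string used below is an assumption for illustration only; inspect datasets/example_dataset/test.oraclewordnssource for the exact token expected by the pretrained models:

# minimal sketch: build one input line from keywords and a source article
# NOTE: the separator below is hypothetical; check the example data files for the real token
keywords = ["health", "care", "bill"]
source = "The Senate passed the health care bill on Tuesday after a lengthy debate ..."
sep = " => "  # assumed separator, not confirmed by this README
input_line = " ".join(keywords) + sep + source
print(input_line)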

Train CTRLsum

Data Processing

Prepare your data files in datasets/[dataset name]; the directory should consist of six data files named [train/val/test].[source/target]. These files are raw text with one example per line. Below we take the cnndm dataset as an example (see here for instructions on obtaining the cnndm dataset):

# this command runs the preprocessing pipeline including tokenization, truncation, and 
# keywords extraction. It will generate all required data files to train CTRLsum into 
# `datasets/cnndm`. Example obtained files can be found in `datasets/example_dataset`
# Some optional arguments can be found in preprocess.py
python scripts/preprocess.py cnndm --mode pipeline

# gpt2 encoding
bash scripts/gpt2_encode.sh cnndm

# binarize dataset for fairseq
bash scripts/binarize_dataset.sh cnndm

Among the generated files in datasets/cnndm, the suffix oracleword denotes the keywords file (after keyword dropout), oraclewordsource denotes the concatenated keywords and source, and oraclewordns denotes the original keywords without keyword dropout. The .jsonl files are potentially used to train the tagger later.
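
As background, here is a heavily simplified sketch of the oracle keyword extraction idea: greedily pick source sentences that add the most overlap with the reference summary, then keep, in source order, the words of those sentences that also appear in the reference. Plain word overlap stands in for ROUGE here and sub-word details are ignored; the authoritative logic is in scripts/preprocess.py.

# simplified sketch of oracle keyword extraction (plain word overlap stands in for ROUGE)
def oracle_keywords(source_sentences, reference_summary):
    ref_words = set(reference_summary.lower().split())
    covered, selected = set(), []
    while True:
        # greedily add the sentence that contributes the most uncovered reference words
        gains = [
            (len(set(s.lower().split()) & (ref_words - covered)), i)
            for i, s in enumerate(source_sentences) if i not in selected
        ]
        best_gain, best_idx = max(gains, default=(0, -1))
        if best_gain == 0:
            break
        selected.append(best_idx)
        covered |= set(source_sentences[best_idx].lower().split()) & ref_words
    # keywords are the words of the selected sentences that also occur in the reference
    return [w for i in sorted(selected)
            for w in source_sentences[i].split() if w.lower() in ref_words]

print(oracle_keywords(
    ["The Senate passed the health care bill .", "Weather was mild in Washington ."],
    "senate passes health care bill"))
# -> ['Senate', 'health', 'care', 'bill']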

Train the summarization model on multiple GPUs:

bash scripts/train_bart.sh -g [GPUs] -d [dataset name]

GPUs are GPU ids separated by ,. All our experiments run on 8 GPUs accumulating 8 gradient steps, resulting in an effective batch size of 1024x8x8 tokens in total. You probably need to increase the update_freq variable in train_bart.sh if you use fewer GPUs, so as to match the effective batch size. The saved models are written to the checkpoint directory. The training arguments can be found in train_bart.sh.
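
The arithmetic behind update_freq is simply: effective tokens per update = max tokens per GPU x number of GPUs x update_freq. A small sketch using the numbers quoted above (treat the 1024 max-tokens value as an assumption if your train_bart.sh sets something else):

# reference setup from the text above: 1024 tokens per GPU * 8 GPUs * update_freq 8
target_tokens = 1024 * 8 * 8

# example: with only 2 GPUs, choose update_freq so the effective batch size matches
num_gpus, max_tokens = 2, 1024
update_freq = target_tokens // (max_tokens * num_gpus)
print(update_freq)  # -> 32 gradient accumulation steps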

Train the keyword tagger (optional):

Note that the keyword tagger is required only in the uncontrolled summarization setting and in control settings that require automatic keywords (such as length control in the paper).

# this uses 4 GPUs for training by default;
# you need to change the --nproc_per_node value if you
# train with a different number of GPUs
bash scripts/train_seqlabel.sh -g [GPUs] -d [dataset name]

The effective batch size we used for different datasets can be found in the training script as number of gpus x batch x update_freq.

Evaluate CTRLsum

Here we include evaluation for uncontrolled summarization settings.

Obtain automatic keywords from a trained tagger:

# run prediction from the tagger which outputs confidence values for every token
# `checkpoint directory` is the directory that contains the `pytorch_model.bin` checkpoint.
# the results are saved in the checkpoint directory as test_predictions.txt
bash scripts/train_seqlabel.sh -g [GPUs] -d [dataset name] -p [checkpoint directory]


# obtain keywords by selecting confident words; `threshold`, `maximum-word`, and `summary-size`
# are the three hyperparameters in this step. Please check Appendix A in the paper for the specific
# values we used for different datasets; the performance is relatively robust to these choices
# (see the sketch after this command)
# this command will yield a file `.predwordsource` in `datasets/[dataset name]` which can be
# used as input to the summarization model to obtain uncontrolled summaries
python scripts/preprocess.py [dataset name] \
		--split test \
		--mode process_tagger_prediction \
		--tag-pred [the tagger prediction file path, named as test_predictions.txt] \
		--threshold [confidence threshold] \
		--maximum-word [maximum number of keywords] \
		--summary-size [number of sentences from which to identify keywords]
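
For intuition, here is a hedged sketch of one plausible way these three hyperparameters could interact (rank sentences by mean token confidence, keep the top summary-size sentences, then pick tokens above the threshold up to maximum-word words); the exact rule is implemented in scripts/preprocess.py and may differ:

# hedged sketch: turn per-token tagger confidences into keywords
def select_keywords(sentences, confidences, threshold, maximum_word, summary_size):
    """sentences: list of token lists; confidences: parallel list of float lists."""
    # assumption: rank sentences by mean token confidence and keep the top `summary_size`
    ranked = sorted(range(len(sentences)),
                    key=lambda i: sum(confidences[i]) / max(len(confidences[i]), 1),
                    reverse=True)
    keep = set(ranked[:summary_size])

    keywords = []
    for i in sorted(keep):
        for tok, conf in zip(sentences[i], confidences[i]):
            if conf >= threshold and len(keywords) < maximum_word:
                keywords.append(tok)
    return keywords

print(select_keywords(
    [["the", "bill", "passed"], ["weather", "was", "mild"]],
    [[0.1, 0.9, 0.8], [0.2, 0.1, 0.3]],
    threshold=0.5, maximum_word=5, summary_size=1))
# -> ['bill', 'passed']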

Metrics:

We report ROUGE scores and BERTScore in the paper. The ROUGE scores in the paper are computed using files2rouge, which is a wrapper of a wrapper of the original ROUGE Perl scripts. Please refer to scripts/test_bart.sh for our evaluation script:

# you will need the Stanford CoreNLP Java toolkit to run this; we use it for tokenization
# this script computes ROUGE and (optionally) BERTScore.
bash scripts/test_bart.sh -g [GPUs] -s [source file name, NOT full path] -d [dataset] -p [ctrlsum checkpoint directory]
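
As a hedged illustration of the BERTScore part only (this uses the bert-score Python package's documented interface, not the project's test_bart.sh pipeline; the file names are hypothetical):

# pip install bert-score
from bert_score import score

# hypothetical file names: one summary / one reference per line, aligned by line number
with open("hypo.txt") as f:
    cands = [line.strip() for line in f]
with open("ref.txt") as f:
    refs = [line.strip() for line in f]

# returns per-example precision, recall, and F1 tensors
P, R, F1 = score(cands, refs, lang="en")
print(f"BERTScore F1: {F1.mean().item():.4f}")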

Citation

@article{he2020ctrlsum,
title={{CTRL}sum: Towards Generic Controllable Text Summarization},
author={He, Junxian and Kry{\'s}ci{\'n}ski, Wojciech and McCann, Bryan and Rajani, Nazneen and Xiong, Caiming},
journal={arXiv},
year={2020}
}
