TASA: Twin Answer Sentences Attack for Adversarial Context Generation in Question Answering

Implementation for the EMNLP 2022 paper TASA by Yu Cao, Dianqi Li, Meng Fang, Tianyi Zhou, Jun Gao, Yibing Zhan and Dacheng Tao.

[Figure: TASA framework]

Environment

You need Python >= 3.7.0.

The main required packages are listed in requirements.txt in the root directory.

You may still encounter missing-dependency errors; please fix them according to the system output.

We use a single 16GB V100 GPU or a single 12GB RTX 3080 Ti GPU in our experiments. At least 12GB of GPU memory is needed.
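
As a quick sanity check of the environment, you can run a minimal sketch like the following (it assumes PyTorch is installed and a CUDA GPU is visible):

import sys
import torch

# Verify the Python version and that the visible GPU has enough memory.
assert sys.version_info >= (3, 7), "Python >= 3.7.0 is required"
assert torch.cuda.is_available(), "a CUDA GPU is required"
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 2**30
print(f"GPU memory: {gpu_mem_gb:.1f} GB (at least 12 GB is needed)")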

Preparation

1 You need to download the following models (a download sketch follows the list)

  1. USE (Universal Sentence Encoder) model, put it under USE_PATH
  2. GPT-2 small model, put it under GPT2_PATH
  3. BERT base uncased model, put it under BERT_PATH
  4. SpanBERT large cased model, put it under SPANBERT_PATH
  5. RoBERTa base model, put it under ROBERTA_PATH
  6. GloVe 6B 100d embeddings, put them under GLOVE_PATH
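
If you do not already have these checkpoints locally, the Transformer models can usually be pulled from the Hugging Face hub. Below is a minimal sketch; the hub IDs are assumptions rather than values taken from this repo. USE is distributed via TensorFlow Hub and GloVe 6B via the Stanford NLP site, so download those two manually into USE_PATH and GLOVE_PATH.

# Hedged sketch: fetch the Transformer checkpoints from the Hugging Face hub
# (assumed hub IDs) and save them under the local paths used in the commands below.
from transformers import AutoModel, AutoTokenizer

MODELS = {
    "GPT2_PATH": "gpt2",                               # small GPT-2
    "BERT_PATH": "bert-base-uncased",                  # BERT base uncased
    "SPANBERT_PATH": "SpanBERT/spanbert-large-cased",  # SpanBERT large cased
    "ROBERTA_PATH": "roberta-base",                    # RoBERTa base
}

for local_path, hub_id in MODELS.items():
    AutoModel.from_pretrained(hub_id).save_pretrained(local_path)
    try:
        AutoTokenizer.from_pretrained(hub_id).save_pretrained(local_path)
    except (OSError, ValueError):
        # Some hub repos (e.g. SpanBERT) may not ship tokenizer files;
        # fall back to the matching BERT cased tokenizer in that case.
        AutoTokenizer.from_pretrained("bert-large-cased").save_pretrained(local_path)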

2 Download QA datasets

Put them under ./data/DATASET_NAME; an example is given in ./data/squad/. You need to write a Python file DATASET_NAME.py as the data loader (see squad.py, where we use dev-v1.1.json for both the training and dev sets).
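
For reference, SQuAD-format files are nested as data -> paragraphs -> qas. A minimal sketch that walks such a file (field names follow the public SQuAD v1.1 schema):

import json

# Iterate a SQuAD-format file, e.g. ./data/squad/dev-v1.1.json.
with open("./data/squad/dev-v1.1.json") as f:
    squad = json.load(f)

for article in squad["data"]:
    for paragraph in article["paragraphs"]:
        context = paragraph["context"]
        for qa in paragraph["qas"]:
            question, qa_id = qa["question"], qa["id"]
            answers = [a["text"] for a in qa["answers"]]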

For datasets from MRQA, including NewsQA, Natural Questions, HotpotQA, and TriviaQA, use ./utility_scripts/convert_mrqa_to_squad.py to convert them into SQuAD format.

3 Train an answerability determination model

Use utility_scripts/get_no_answer_dataset.py to obtain training samples that include unanswerable examples for the dataset. You will get two JSON files named DATASET_NAME_TRAIN.json_no_answer and DATASET_NAME_DEV.json_no_answer; use them to create a new directory ./data/DATASET_NAME_NO_ANSWER as the dataset path for training.
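
A minimal sketch of that layout step is below; the destination file names inside the new directory are assumptions and must match whatever your DATASET_NAME_NO_ANSWER data loader expects.

import os
import shutil

# Collect the generated no-answer files into the new dataset directory.
dst = "./data/DATASET_NAME_NO_ANSWER"
os.makedirs(dst, exist_ok=True)
shutil.copy("DATASET_NAME_TRAIN.json_no_answer", os.path.join(dst, "train.json"))  # assumed name
shutil.copy("DATASET_NAME_DEV.json_no_answer", os.path.join(dst, "dev.json"))      # assumed name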

Then train the RoBERTa model using the following command

python train_squad.py \
--model_name_or_path ROBERTA_PATH \
--dataset_name ./data/DATASET_NAME_NO_ANSWER \
--output_dir DETERMINE_MODEL_PATH \
--version_2_with_negative 1 \

The trained determination model will be saved under DETERMINE_MODEL_PATH

4 Train the victim model $F(\cdot)$

Train the BERT model using the following command and obtain the trained BERT under TRAINED_BERT_PATH

python train_squad.py \
--model_name_or_path BERT_PATH \
--dataset_name ./data/DATASET_NAME \
--output_dir TRAINED_BERT_PATH \
--version_2_with_negative 0 \

Train the SpanBERT model using the following command and obtain the trained SpanBERT under TRAINED_SPANBERT_PATH

python train_squad.py \
--model_name_or_path SPANBERT_PATH \
--max_seq_length 512 \
--do_lower_case 0 \
--learning_rate 2e-5 \
--dataset_name ./data/DATASET_NAME \
--output_dir TRAINED_SPANBERT_PATH \
--version_2_with_negative 0 \

Train the BiDAF model using the following command and obtain the trained BiDAF under TRAINED_BIDAF_PATH

python train_bidaf.py \
--config_file ./bidaf/bidaf.jsonnet \
--save_path TRAINED_BIDAF_PATH \
--train_file ./data/DATASET_NAME/TRAIN_FILE.json \
--dev_file ./data/DATASET_NAME/DEV_FILE.json \
--cache_file ./data/DATASET_NAME/cache.bin \
--vocab_file ./data/TRAINED_BIDAF_PATH/vocabulary \
--passage_length_limit 800 \
--num_gradient_accumulation_steps 2 \

5 Get the coreference file for dev set

Use the following command to get the file COREFERENCE_FILE, which contains the coreference relationships of the target attack dataset

python ./utility_scripts/get_coreference.py \
--input_file ./data/DATASET_NAME/DEV_FILE.json \
--output_file COREFERENCE_FILE \

6 Get the named entity dictionary and POS tag dictionary for the current dataset

Use the script ./utility/extact_entities_pos_vocab.py. You can set the input dataset JSON files and the output file paths within the script to obtain ENT_DICT and POS_VOCAB_DICT, which will be used for candidate sampling during the attack.
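
For illustration only (this is not the repo's script, and its output format may differ), the two dictionaries could be built with spaCy along these lines:

import json
from collections import defaultdict

import spacy

# Illustrative sketch: collect named entities per label and words per POS tag
# from the dataset contexts.
nlp = spacy.load("en_core_web_sm")
ent_dict = defaultdict(set)   # entity label -> surface forms
pos_vocab = defaultdict(set)  # POS tag -> words

with open("./data/squad/dev-v1.1.json") as f:
    squad = json.load(f)

for article in squad["data"]:
    for paragraph in article["paragraphs"]:
        doc = nlp(paragraph["context"])
        for ent in doc.ents:
            ent_dict[ent.label_].add(ent.text)
        for tok in doc:
            pos_vocab[tok.pos_].add(tok.text)

with open("ENT_DICT", "w") as f:
    json.dump({k: sorted(v) for k, v in ent_dict.items()}, f)
with open("POS_VOCAB_DICT", "w") as f:
    json.dump({k: sorted(v) for k, v in pos_vocab.items()}, f)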

Attack

Use the following command to attack the BERT model and obtain the adversarial samples in BERT_ATTACK_OUTPUT

python TASA.py \
--target_dataset_file ./data/DATASET_NAME/DEV_FILE.json \
--target_model TRAINED_BERT_PATH \
--target_model_type bert \
--output_dir BERT_ATTACK_OUTPUT \
--target_dataset_type squad \
--ent_dict_file ENT_DICT \
--coreference_file COREFERENCE_FILE \
--pos_vocab_dict_file POS_VOCAB_DICT \
--USE_model_path USE_PATH \
--ppl_model_path GPT2_PATH \
--determine_model_path DETERMINE_MODEL_PATH \
--beam_size 5 \

Use the following command to attack the SpanBERT model and obtain the adversarial samples in SPANBERT_ATTACK_OUTPUT

python TASA.py \
--target_dataset_file ./data/DATASET_NAME/DEV_FILE.json \
--target_model TRAINED_SPANBERT_PATH \
--target_model_type spanbert \
--output_dir SPANBERT_ATTACK_OUTPUT \
--target_dataset_type squad \
--ent_dict_file ENT_DICT \
--coreference_file COREFERENCE_FILE \
--pos_vocab_dict_file POS_VOCAB_DICT \
--USE_model_path USE_PATH \
--ppl_model_path GPT2_PATH \
--determine_model_path DETERMINE_MODEL_PATH \
--beam_size 5 \

Similarly, use the following command to attack the BiDAF model and obtain the adversarial samples in BIDAF_ATTACK_OUTPUT

python TASA.py \
--target_dataset_file ./data/DATASET_NAME/DEV_FILE.json \
--target_model TRAINED_BIDAF_PATH \
--target_model_type bidaf \
--output_dir BIDAF_ATTACK_OUTPUT \
--target_dataset_type squad \
--ent_dict_file ENT_DICT \
--coreference_file COREFERENCE_FILE \
--pos_vocab_dict_file POS_VOCAB_DICT \
--USE_model_path USE_PATH \
--ppl_model_path GPT2_PATH \
--determine_model_path DETERMINE_MODEL_PATH \
--beam_size 5 \

Test adversarial samples

You can use the following command to test the performance of models on the adversarial samples

python train_squad.py \
--model_name_or_path TRAINED_BERT_PATH \
--dataset_name ./data/ADVERSARIAL_DATA \
--output_dir TRAINED_BERT_PATH \
--do_predict \

Or

python train_squad.py \
--model_name_or_path TRAINED_SPANBERT_PATH \
--dataset_name ./data/ADVERSARIAL_DATA \
--output_dir TRAINED_SPANBERT_PATH \
--max_seq_length 512 \
--do_lower_case 0 \
--do_predict \

Or

python train_bidaf.py \
--config_file ./bidaf/bidaf.jsonnet \
--save_path debugger_train \
--dev_file ./data/ADVERSARIAL_DATA/ADVERSARIAL.json \
--cache_file ./data/ADVERSARIAL_DATA/cache_test.bin \
--vocab_file ./TRAINED_BIDAF_PATH/vocabulary \
--passage_length_limit 800 \
--do_predict \
--model_path TRAINED_BIDAF_PATH \
--prediction_file squad_bidaf_predictions.json \

Attacked datasets

Updated on Jun 24, 2024.

We have provided some adversarial samples produced by the attack under ./data/attack_data. Each .zip file includes the attacked samples for the corresponding dataset. Since this repo dates from a long time ago, it is hard for me to find all of the attack results (I put them on the storage servers of USYD, but I no longer have access to them after my graduation).

tasa's Issues

About Attack Dataset

Hello,
I recently tried to run TASA but encountered some issues. After converting the MRQA dataset I got only one JSON file, which I split into dev.json (10% of the data) and train.json (90% of the data), both in exactly the same format as TASA/data/squad/dev-v1.1.json. I then trained the BERT and SpanBERT models, and the training epochs ran successfully. I also obtained the coreference file as described. But when I launched the attack by running TASA.py, some stages (add overlap token, get edit importance, and get edit synonyms) reached 100% progress, yet no adversarial samples were generated. Instead, it showed the following error when I attacked TASA/data/HotpotQA/dev.json:

File "TASA.py", line 998, in main
    "coreferences": coreferences[d_i][p_i]})
IndexError: list index out of range

I also tried your example dataset (TASA/data/squad/dev-v1.1.json) to get adversarial samples from the attack, but I got this error:

 raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx))
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'outputs'

I'm also attaching the log file here.
output.log

Please let me know how to solve this issue. It would also be great if you could provide the correct dataset or data format needed to successfully run the attack and obtain adversarial samples.

About Attack Dataset

Hello, I have recently been working on replicating the experiments from TASA, but I've encountered some issues. Could you provide me with the validation sets of NewsQA, Natural Questions, HotpotQA, and TriviaQA after they have been attacked by Textfloor, T3, and ADDsent?
