kaushikpunyamurthula / selective_prediction_mtl

This project is forked from arutselvan/selective_prediction_mtl

Investigation of how sampling strategies affect Selective Prediction performance in Multi Task Learning

Python 100.00%

Investigating the Impact of Multi-Task Learning strategies on Selective Prediction

Clone the repo

git clone https://github.com/Arutselvan/selective_prediction_mtl

Change the current directory to the cloned repository

cd selective_prediction_mtl

The training and evaluation datasets are in the dataset folder.

To convert the dataset into QA format [Question, Context, Answer]:

python code/squad_converter.py

This generates the training files, the combined dataset (main.json), and the evaluation samples, and stores them in the dataset folder.
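As an illustration of the [Question, Context, Answer] layout, here is a minimal sketch of one SQuAD-style record as consumed by extractive QA scripts. It is not taken from code/squad_converter.py: the helper name to_squad_record and the sample question and context are invented for this example, and the field layout assumes the standard SQuAD schema.

```python
import json


def to_squad_record(example_id, question, context, answer):
    """Build one SQuAD-style record (hypothetical helper, not the repo's converter).

    Extractive QA needs the answer to appear verbatim in the context,
    because the model predicts a character span rather than free text.
    """
    start = context.find(answer)
    if start < 0:
        raise ValueError(f"answer {answer!r} not found in context")
    return {
        "id": example_id,
        "question": question,
        "context": context,
        "answers": {"text": [answer], "answer_start": [start]},
    }


# Invented multiple-choice sample: the options are placed in the context
# so that the correct option can be selected as a span.
record = to_squad_record(
    "csqa_0",
    "What might a broken machine do?",
    "Options: run smoothly, fail to work, make music",
    "fail to work",
)
print(json.dumps(record, indent=2))
```

Placing the candidate answers inside the context is one way to cast a classification task as span extraction; the actual converter may lay out the context differently.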

To train with heterogeneous sampling and evaluate on the eval files of all datasets, run the following command from the root directory of the repo:

python code/run_qa.py --model_name_or_path bert-base-cased --do_train --train_file main.json --validation_files "['snli_squad_eval.json', 'swag_squad_eval.json', 'csqa_squad_eval.json', 'anli_squad_eval.json', 'siqa_squad_eval.json']"  --max_seq_length 256 --output_dir ./output-hetero --overwrite_output_dir --num_train_epochs 5 --evaluation_strategy epoch --per_device_train_batch_size 16 --per_device_eval_batch_size 32

To train with homogeneous sampling and evaluate on the eval files of all datasets, run the following command from the root directory of the repo (note that the flag value is spelled Homogenous):

python code/run_qa.py --model_name_or_path bert-base-cased --do_train --train_file main.json --validation_files "['snli_squad_eval.json', 'swag_squad_eval.json', 'csqa_squad_eval.json', 'anli_squad_eval.json', 'siqa_squad_eval.json']"  --max_seq_length 256 --output_dir ./output-homo --overwrite_output_dir --num_train_epochs 5 --evaluation_strategy epoch --per_device_train_batch_size 16 --per_device_eval_batch_size 32 --sampling Homogenous
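The difference between the two sampling modes can be sketched as follows. This is not the logic in code/run_qa.py; it assumes the usual multi-task meaning of the terms, where a heterogeneous batch mixes examples from all tasks and a homogeneous batch draws from a single task. The function make_batches and the default value "Heterogeneous" are hypothetical; only the value Homogenous appears in the commands above.

```python
import random


def make_batches(task_examples, batch_size, sampling="Heterogeneous", seed=0):
    """Group examples into batches (illustrative sketch, not the repo's sampler).

    task_examples maps a task name to its list of examples.
    """
    rng = random.Random(seed)

    def chunk(items):
        return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    if sampling == "Homogenous":
        # Every batch comes from exactly one task; batch order is shuffled
        # so that tasks are still interleaved during training.
        batches = []
        for examples in task_examples.values():
            shuffled = list(examples)
            rng.shuffle(shuffled)
            batches.extend(chunk(shuffled))
        rng.shuffle(batches)
        return batches

    # Heterogeneous: pool all tasks, shuffle, then cut into batches,
    # so a single batch can contain examples from several tasks.
    pooled = [ex for examples in task_examples.values() for ex in examples]
    rng.shuffle(pooled)
    return chunk(pooled)


# Toy data: each example is a (task, index) pair.
data = {
    "snli": [("snli", i) for i in range(8)],
    "swag": [("swag", i) for i in range(8)],
}
homogeneous = make_batches(data, batch_size=4, sampling="Homogenous")
heterogeneous = make_batches(data, batch_size=4)
```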

Note: The evaluation accuracy metrics reported by this command will be empty or zero, because the prediction format was changed for the purpose of selective prediction.

The predictions have the following format:

"csqa_0": {
    "prediction": "fail to work",
    "maxProb": "0.9998784"
}
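The maxProb field is what makes selective prediction possible: the model answers only when its confidence is high enough and abstains otherwise. A minimal sketch, assuming the predictions are loaded as a dictionary in the format above. The function select_predictions, the threshold value, and the csqa_1 entry are invented for this example.

```python
def select_predictions(predictions, threshold):
    """Split predictions into answered and abstained by confidence (sketch)."""
    answered, abstained = {}, []
    for example_id, entry in predictions.items():
        # maxProb is stored as a string, so convert it before comparing.
        if float(entry["maxProb"]) >= threshold:
            answered[example_id] = entry["prediction"]
        else:
            abstained.append(example_id)
    return answered, abstained


predictions = {
    "csqa_0": {"prediction": "fail to work", "maxProb": "0.9998784"},
    # Hypothetical low-confidence entry, added only to show abstention.
    "csqa_1": {"prediction": "make music", "maxProb": "0.41"},
}
answered, abstained = select_predictions(predictions, threshold=0.9)
```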

To create files containing the exact-match result and maximum probability for the predictions of both the homogeneous and heterogeneous models, run:

python code/evaluate.py

This creates another .json file whose entries have the following format:

{
    "expected_prediction": "Levin was very successful at running the store",
    "prediction": "Levin was very successful at running the store",
    "correct": true,
    "maxProb": "0.99999547"
}

Note: For the above command to execute correctly, all previous steps need to be followed exactly, because the folder paths are hardcoded.
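For reference, an entry in that format can be derived from a prediction and its expected answer as sketched below. This is an assumption about what code/evaluate.py does, not a copy of it: the function evaluate_prediction is hypothetical, and it uses strict string equality, whereas the real script may normalize text before comparing.

```python
def evaluate_prediction(expected, entry):
    """Attach an exact-match flag to one prediction entry (illustrative sketch)."""
    return {
        "expected_prediction": expected,
        "prediction": entry["prediction"],
        # Strict equality is assumed here; normalization is not applied.
        "correct": entry["prediction"] == expected,
        "maxProb": entry["maxProb"],
    }


result = evaluate_prediction(
    "Levin was very successful at running the store",
    {
        "prediction": "Levin was very successful at running the store",
        "maxProb": "0.99999547",
    },
)
```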

To view plots of selective prediction metrics for all datasets (for both homogeneous and heterogeneous sampling), run:

python code/evaluate_sp.py

This creates plot files comparing the selective prediction metrics of the two models on each dataset. The plots are shown one at a time, one per dataset; close the current plot to view the next.

Note: The text in a plot may appear cluttered on some devices. Make the plot window full screen to read it clearly.

Example plot:

[Figure: snli_sp_metrics_mtl]
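The README does not list which metrics code/evaluate_sp.py plots. A common choice for selective prediction is the risk-coverage curve and the area under it, which can be computed from the evaluated entries (correct, maxProb) as in this sketch. The function names and the four sample records are invented for illustration.

```python
def risk_coverage_curve(records):
    """Return (coverage, risk) points, answering the most confident items first."""
    ranked = sorted(records, key=lambda r: float(r["maxProb"]), reverse=True)
    curve, errors = [], 0
    for answered, record in enumerate(ranked, start=1):
        if not record["correct"]:
            errors += 1
        coverage = answered / len(ranked)  # fraction of questions answered
        risk = errors / answered           # error rate among answered questions
        curve.append((coverage, risk))
    return curve


def area_under_curve(curve):
    """Trapezoidal area under the risk-coverage curve; lower is better."""
    area = 0.0
    for (c0, r0), (c1, r1) in zip(curve, curve[1:]):
        area += (c1 - c0) * (r0 + r1) / 2
    return area


# Invented records in the format written by the evaluation step.
records = [
    {"correct": True, "maxProb": "0.99"},
    {"correct": True, "maxProb": "0.90"},
    {"correct": False, "maxProb": "0.60"},
    {"correct": True, "maxProb": "0.50"},
]
curve = risk_coverage_curve(records)
auc = area_under_curve(curve)
```

A model whose confidence ranks its mistakes last keeps the risk low at small coverage, which shows up as a smaller area under the curve.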
