This is the repository for the MSc Computing project. It contains the following files and directories:
- `trainer.py`: contains our Trainer class
- `loss.py`: contains the custom loss functions that we used
- `cleaning.py`: script to clean the Refinitiv raw data
- `data_analysis.py`: contains a class named Preprocessor to process data for the baseline model
- `embedding_viewer.py`: contains a class named BertEmbeddingView for model interpretation and visualisation
- `data.py`: contains the PyTorch dataset class to feed the inputs
- `model.py`: contains the Model class we used
- `my_tokenizers`: contains the word-level tokeniser for the LSTM model
- `xgb_train.py`: contains the baseline XGBoost model
- `metrics.py`: contains the evaluation metrics used in the project
- `FGM.py`: adversarial attack training technique
- `utils.py`: contains useful functions for data analysis and manipulation
- `main.py`: the main script that runs the experiment
- `transfer.py`: contains the transfer learning experiment
- `models`: directory containing saved models
- `emb_vis`: directory containing visualisations for model interpretability
- `result`: directory containing the results on the test set
transformers 4.16.2
spacy 3.2.0
torch 1.10.0
numpy 1.19.5
pandas 1.1.5
scikit-learn 0.24.2
tqdm 4.62.3
xgboost 1.4.2
captum 0.5.0
upsetplot 0.6.1
beautifulsoup4 4.6.3
wordcloud 1.8.2.2
seaborn 0.11.2
matplotlib 3.2.2
- To replicate the experiment, first place your data in CSV format in the current folder, with a column `story` that contains the news stories. You can include as many controversial topics as you need, each with its own column in which `1` represents the presence of the controversy.
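As a rough illustration, the expected input CSV could look like the example below. Only the `story` column is required by the description above; the topic column names (`privacy`, `environment`) are invented for this sketch.

```python
import csv
import io

# Build a minimal example of the expected input CSV in memory:
# one `story` column plus one binary column per controversial topic.
# The topic names below are invented for illustration.
rows = [
    {"story": "Company X fined over a data breach.", "privacy": 1, "environment": 0},
    {"story": "Oil spill prompts regulatory review.", "privacy": 0, "environment": 1},
]

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["story", "privacy", "environment"])
writer.writeheader()
writer.writerows(rows)

header = buf.getvalue().splitlines()[0]
print(header)  # story,privacy,environment
```
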
- Run `cleaning.py`, changing the file name in the script to your CSV file. The resulting cleaned file is named `cleaned_2.csv`; you can rename it as you wish.
- Please modify the `config` variable in `main.py` to test different hyperparameters/tricks, or leave it as it is for the best model found in the project.
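The authoritative keys of `config` are defined in `main.py`; the following is only an assumed sketch. The numeric values mirror the example saved-model name shown later in this README (learning rate 3e-06, 10 epochs, max length 512), but the key names themselves are guesses.

```python
# Assumed sketch of a `config` dict -- the real keys live in main.py.
# Key names are guesses; values mirror the example saved-model name.
config = {
    "model_name": "ProsusAI/finbert",  # Hugging Face checkpoint
    "learning_rate": 3e-6,
    "epochs": 10,
    "max_length": 512,
    "use_fgm": False,  # adversarial training via FGM.py (assumed flag name)
}
```
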
- To train the deep learning model, run `python3 main.py`. To train the baseline model, run `python3 xgb_train.py`.
- Have your trained model ready in the `models` folder, e.g. `models/ProsusAI_finbert_head_3e-06_10_512_False_None_saved_model.pt`
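The saved-model file name appears to encode the training hyperparameters. A hedged guess at the scheme is sketched below; the helper `saved_model_name` and its argument order are assumptions for illustration, not code from the repository.

```python
def saved_model_name(model_name, head, lr, epochs, max_len, flag_a, flag_b):
    """Reconstruct the apparent checkpoint naming scheme.

    This helper is an assumption made for illustration; it is not
    defined anywhere in the repository.
    """
    base = model_name.replace("/", "_")  # "ProsusAI/finbert" -> "ProsusAI_finbert"
    return f"{base}_{head}_{lr}_{epochs}_{max_len}_{flag_a}_{flag_b}_saved_model.pt"

name = saved_model_name("ProsusAI/finbert", "head", 3e-06, 10, 512, False, None)
print(name)  # ProsusAI_finbert_head_3e-06_10_512_False_None_saved_model.pt
```
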
- Run `python3 embedding_viewer.py`
- Have your trained model ready in the `models` folder
- Have the target-domain data source available, e.g. `twitter_data.csv`
- Run `python3 transfer.py`
- Please make sure you have all required libraries installed
- `MyBertModel` is the base class for the BERT model implementation. Use `MyBertModel_1` for the default classification head and `MyBertModel_2` for the customised classification head.
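The two classification heads could be sketched roughly as follows. This is an illustrative guess using plain `torch.nn`; the real implementations in `model.py` may differ in layer sizes, activations, and naming.

```python
import torch
import torch.nn as nn

class DefaultHead(nn.Module):
    """Single linear layer on the pooled BERT embedding (MyBertModel_1-style guess)."""
    def __init__(self, hidden_size, num_labels):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled):
        return self.classifier(pooled)

class CustomHead(nn.Module):
    """Extra hidden layer plus dropout before classifying (MyBertModel_2-style guess)."""
    def __init__(self, hidden_size, num_labels, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, pooled):
        return self.net(pooled)

pooled = torch.randn(4, 768)          # batch of 4 pooled BERT embeddings
logits = CustomHead(768, 2)(pooled)   # shape: (4, 2)
```
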
- The model is saved in the `models` directory.