sentiment_analysis_fine_grain

Multi-label classification with BERT; fine-grained sentiment analysis from AI Challenger.

Introduction

With this repository, you can train a multi-label classifier with BERT and deploy BERT for online prediction.

You can also find a short tutorial on how to use BERT with Chinese: BERT short chinese tutorial.

You can find an introduction to fine-grained sentiment analysis from AI Challenger.

Basic Ideas

Add something here.

Experiment on New Models

For more, check model/bert_cnn_fine_grain_model.py.

Performance

| Model | F1 Score |
| --- | --- |
| TextCNN (no pre-train) | 0.678 |
| TextCNN (pre-train, fine-tuning) | 0.685 |
| BERT (base_model_zh) | ADD A NUMBER HERE |
| BERT (base_model_zh, pre-trained on corpus) | ADD A NUMBER HERE |

Notice: F1 Score is reported on the validation set.

Usage

BERT for Multi-label Classification [data for fine-tuning and pre-train]

export BERT_BASE_DIR=BERT_BASE_DIR/chinese_L-12_H-768_A-12
export TEXT_DIR=TEXT_DIR
nohup python run_classifier_multi_labels_bert.py \
  --task_name=sentiment_analysis \
  --do_train=true \
  --do_eval=true \
  --data_dir=$TEXT_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=512 \
  --train_batch_size=4 \
  --learning_rate=2e-5 \
  --num_train_epochs=3 \
  --output_dir=./checkpoint_bert &

1. First, download the pre-trained model from Google and put it in a folder (e.g. BERT_BASE_DIR):

   chinese_L-12_H-768_A-12 from <a href='https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip'>bert</a>

2. Second, you need training data (e.g. train.tsv) and validation data (e.g. dev.tsv) under a folder (e.g. TEXT_DIR). You can also download data from here: <a href='https://pan.baidu.com/s/1ZS4dAdOIAe3DaHiwCDrLKw'>data to train bert for AI challenger-Sentiment Analysis</a>.

   It contains processed data you can use both for fine-tuning on sentiment analysis and for pre-training with BERT. It was generated by following preprocess_char.ipynb step by step. You can also generate the data yourself, as long as the format is compatible with the processor SentimentAnalysisFineGrainProcessor (alias: sentiment_analysis).


Data format: label1,label2,label3\t here is sentence or sentences\t

It contains only two columns: the first is the target (one or more labels), the second is the input string. No tokenization is needed.

Sample: "0_1,1_-2,2_-2,3_-2,4_1,5_-2,6_-2,7_-2,8_1,9_1,10_-2,11_-2,12_-2,13_-2,14_-2,15_1,16_-2,17_-2,18_0,19_-2 浦东五莲路站,老饭店福瑞轩属于上海的本帮菜,交通方便,最近又重新装修,来拨草了,饭店活动满188元送50元钱,环境干净,简单。朋友提前一天来预订包房也没有订到,只有大堂,五点半到店基本上每个台子都客满了,都是附近居民,每道冷菜量都比以前小,味道还可以,热菜烤茄子,炒河虾仁,脆皮鸭,照牌鸡,小牛排,手撕腊味花菜等每道菜都很入味好吃,会员价划算,服务员人手太少,服务态度好,要能团购更好。可以用支付宝方便"

Check the sample data in the ./BERT_BASE_DIR folder.

For more detail, check create_model and SentimentAnalysisFineGrainProcessor in run_classifier.py.
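For illustration, here is a minimal Python sketch of parsing one line in this format into a multi-hot label vector plus the raw text. It assumes 20 aspects with four sentiment classes {1, 0, -1, -2} flattened into an 80-dimensional vector; the exact encoding used by preprocess_char.ipynb may differ.

    # Minimal sketch (not the repo's actual preprocessing): parse one line of
    # "label1,label2,...\t sentence \t" into a multi-hot vector plus raw text.
    NUM_ASPECTS = 20
    SENT_TO_IDX = {"1": 0, "0": 1, "-1": 2, "-2": 3}  # assumed class ordering

    def parse_line(line):
        labels_str, text = line.rstrip("\n").split("\t", 1)
        vector = [0] * (NUM_ASPECTS * len(SENT_TO_IDX))
        for item in labels_str.split(","):
            aspect, sentiment = item.split("_", 1)  # e.g. "18_0" -> aspect 18, sentiment 0
            vector[int(aspect) * len(SENT_TO_IDX) + SENT_TO_IDX[sentiment]] = 1
        return vector, text.strip()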

Pre-train a BERT model based on the open-sourced model, then do the classification task

  1. Generate raw data: [ADD SOMETHING HERE]

    Make sure each line is a sentence; between documents there is a blank line. You can find the generated data in the zip file.

    Use write_pre_train_doc() from preprocess_char.ipynb; a hypothetical sketch of the output format follows.
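    The notebook body is not reproduced in this README; below is a hypothetical re-implementation (the sentence-splitting rule is an assumption):

    import re

    # Hypothetical sketch of write_pre_train_doc(); the real one lives in
    # preprocess_char.ipynb. Output: one sentence per line, with a blank
    # line between documents, as create_pretraining_data.py expects.
    def write_pre_train_doc(documents, output_path):
        with open(output_path, "w", encoding="utf-8") as f:
            for doc in documents:
                # naive split after Chinese/ASCII sentence-ending punctuation
                for sentence in re.split(r"(?<=[。!?.!?])", doc):
                    if sentence.strip():
                        f.write(sentence.strip() + "\n")
                f.write("\n")  # a blank line separates documents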
    
  2. Generate data for the pre-training stage:

    export BERT_BASE_DIR=./BERT_BASE_DIR/chinese_L-12_H-768_A-12
    nohup python create_pretraining_data.py \
    --input_file=./PRE_TRAIN_DIR/bert_*_pretrain.txt \
    --output_file=./PRE_TRAIN_DIR/tf_examples.tfrecord \
    --vocab_file=$BERT_BASE_DIR/vocab.txt \
    --do_lower_case=True \
    --max_seq_length=512 \
    --max_predictions_per_seq=60 \
    --masked_lm_prob=0.15 \
    --random_seed=12345 \
    --dupe_factor=5 > nohup_pre.out &
    
  3. Pre-train the model with the generated data:

    python run_pretraining.py

  4. Fine-tune the classifier:

    python run_classifier.py

TextCNN

  1. Download the cache file for sentiment analysis (tokens are at the word level).

  2. Train the model:

    python train_cnn_fine_grain.py

The cache file for the TextCNN model was generated by following the steps in preprocess_word.ipynb. It contains everything you need to run TextCNN: the processed train/validation/test sets, the word vocabulary, and a dict mapping labels to indices. Take train_valid_test_vocab_cache.pik and put it under the preprocess_word/ folder. The raw data are also included in this zip file.
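A minimal sketch of loading that cache, assuming the pickle holds the splits, vocabulary, and label dict in the order below (the authoritative layout is whatever preprocess_word.ipynb writes):

    import pickle

    # Assumed layout: (train, valid, test, word vocabulary, label->index dict);
    # check preprocess_word.ipynb for the authoritative structure.
    with open("preprocess_word/train_valid_test_vocab_cache.pik", "rb") as f:
        train, valid, test, vocab_word, label2index = pickle.load(f)

    print("train examples:", len(train), "vocab size:", len(vocab_word))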

Pre-train TextCNN

  1. Pre-train TextCNN with a masked language model (see the sketch after this list):

    python train_cnn_lm.py

  2. Fine-tune TextCNN:

    python train_cnn_fine_grain.py
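For intuition, here is a minimal sketch of the masked-language-model objective on token ids; the mask id and masking rate are assumptions, and the actual scheme in train_cnn_lm.py may differ.

    import random

    MASK_ID = 0        # assumed reserved id for the [MASK] token
    MASK_PROB = 0.15   # conventional masked-LM masking rate

    def mask_tokens(token_ids):
        # Replace ~15% of positions with MASK_ID; the LM is then trained to
        # predict the original token id at each masked position only.
        inputs, targets = list(token_ids), []
        for i, tid in enumerate(token_ids):
            if random.random() < MASK_PROB:
                inputs[i] = MASK_ID
                targets.append((i, tid))  # (position, original token id)
        return inputs, targets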

Deploy BERT for online prediction

With the session-and-feed style, you can easily deploy BERT.

For online prediction with BERT, check more from here.
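A minimal sketch of that pattern in TensorFlow 1.x: build the graph and restore the checkpoint once at startup, then serve each request with sess.run plus feed_dict. The tiny stand-in graph and tensor names below are illustrative only, not the repo's actual ones.

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x, matching the repo

    MAX_SEQ_LEN, NUM_LABELS = 512, 80  # 20 aspects x 4 sentiment classes

    # Stand-in for the real BERT graph built by create_model().
    input_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LEN], name="input_ids")
    embedding = tf.get_variable("embedding", [21128, 64])  # 21128 = chinese_L-12 vocab size
    pooled = tf.reduce_mean(tf.nn.embedding_lookup(embedding, input_ids), axis=1)
    probabilities = tf.sigmoid(tf.layers.dense(pooled, NUM_LABELS))  # multi-label output

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # saver.restore(sess, "./checkpoint_bert/model.ckpt-XXXX")  # restore fine-tuned weights instead
        ids = np.zeros([1, MAX_SEQ_LEN], dtype=np.int32)  # a tokenized request would go here
        probs = sess.run(probabilities, feed_dict={input_ids: ids})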

Reference

  1. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

  2. google-research/bert

  3. pengshuang/AI-Comp

  4. AI Challenger 2018

  5. Convolutional Neural Networks for Sentence Classification


Issues

RuntimeError: Attempted to use a closed Session.

Thank you for a nice example, but I bumped into the following error when executing "estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)" in run_classifier_multi_labels_bert.py.

RuntimeError: Attempted to use a closed Session.

Any suggestion?

micro_f1 calc

target labels: [0, 7, 9, 12, 18, 20, 24, 30, 32, 38, 40, 44, 48, 52, 58, 62, 64, 68, 75, 76]
predicted labels: [0, 4, 8, 12, 20, 24, 32, 48, 52, 64, 68]
At prediction time, how can a prediction with fewer than twenty labels be recovered to 20?
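For reference, micro-F1 over index lists like these can be computed by expanding them into the flattened 80-dimensional indicator space (20 aspects x 4 classes) and counting true/false positives globally; whether this matches the repo's exact evaluation is an assumption.

    import numpy as np
    from sklearn.metrics import f1_score

    NUM_CLASSES = 80  # 20 aspects x 4 sentiment classes

    def to_indicator(indices):
        vec = np.zeros(NUM_CLASSES, dtype=int)
        vec[list(indices)] = 1
        return vec

    target = to_indicator([0, 7, 9, 12, 18, 20, 24, 30, 32, 38, 40, 44, 48, 52, 58, 62, 64, 68, 75, 76])
    predict = to_indicator([0, 4, 8, 12, 20, 24, 32, 48, 52, 64, 68])
    # micro-F1 over the multi-label indicator matrix (here, a single example)
    print(f1_score(target.reshape(1, -1), predict.reshape(1, -1), average="micro"))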

Doesn't the encoder use a mask? And is the mask computation in self-attention wrong?

At line 92 of bert_model.py:

    encoder_class = Encoder(self.d_model, self.d_k, self.d_v, self.sequence_length, self.h, self.batch_size,
                            self.num_layer, self.input_representation, self.input_representation,
                            dropout_keep_prob=self.dropout_keep_prob,
                            use_residual_conn=self.use_residual_conn)

Why is the mask parameter never assigned? Does that mean no mask is used by default? But the masking operation in the encoder should be mandatory.

At line 82 of multi_head_attention.py:

    mask = tf.expand_dims(self.mask, axis=-1)  # [batch, sequence_length, 1]
    mask = tf.expand_dims(mask, axis=1)        # [batch, 1, sequence_length, 1]
    dot_product = dot_product + mask           # [batch, h, sequence_length, 1]

How can the mask operation be a direct addition?
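For context, the additive convention used in google-research/bert turns a 0/1 padding mask into large negative numbers that are added to the attention logits before the softmax, so masked positions get near-zero attention weight. A minimal sketch (shapes and names here are illustrative):

    import tensorflow as tf

    def apply_attention_mask(scores, mask):
        # scores: [batch, heads, seq_len, seq_len] attention logits;
        # mask:   [batch, seq_len], 1 for real tokens, 0 for padding.
        adder = (1.0 - tf.cast(mask, tf.float32)) * -10000.0  # 0 keeps, -1e4 hides
        adder = adder[:, tf.newaxis, tf.newaxis, :]           # broadcast over heads and query positions
        return tf.nn.softmax(scores + adder, axis=-1)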

Problems when switching to an English dataset

I replaced the Chinese pre-trained BERT with the English BERT and got the error below. Does something need to be modified? Thanks!
ERROR:tensorflow:Error recorded from evaluation_loop: Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Is there a bug in this eval-performance code?

Looking at the code logic, accuracy is computed from the split probabilities and labels. Should the marked line in the code block below be label_ids_split=tf.split(label_ids,FLAGS.num_aspects,axis=-1)? And is this related to the eval_accuracy errors other people have posted?

    def metric_fn(per_example_loss, label_ids, logits):
        # print("###metric_fn.logits:", logits.shape)  # (?,80)
        # predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        # print("###metric_fn.label_ids:", label_ids.shape, ";predictions:", predictions.shape)  # label_ids: (?,80); predictions: (?,)
        logits_split = tf.split(logits, FLAGS.num_aspects, axis=-1)     # a list; length is num_aspects
        label_ids_split = tf.split(logits, FLAGS.num_aspects, axis=-1)  # <-- marked line: should this split label_ids instead of logits?
        accuracy = tf.constant(0.0, dtype=tf.float64)

        for j, logits in enumerate(logits_split):
            # accuracy = tf.metrics.accuracy(label_ids, predictions)
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)  # should be [batch_size,]
            label_id_ = tf.cast(tf.argmax(label_ids_split[j], axis=-1), dtype=tf.int32)
            print("label_ids_split[j]:", label_ids_split[j], ";predictions:", predictions, ";label_id_:", label_id_)
            current_accuracy, update_op_accuracy = tf.metrics.accuracy(label_id_, predictions)
            accuracy += tf.cast(current_accuracy, dtype=tf.float64)
        accuracy = accuracy / tf.constant(FLAGS.num_aspects, dtype=tf.float64)
        loss = tf.metrics.mean(per_example_loss)
        return {
            "eval_accuracy": (accuracy, update_op_accuracy),
            "eval_loss": loss,
        }

F1 scores and performance

Hi

Great work!

Can I ask what F1 scores and performance you got from running run_classifier_multi_labels_bert.py?

Thanks
