text-classification's Introduction

Text-Classification

Project Overview

Trains a model on labeled texts so that it can classify new texts.

Changelog

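2019.3.25: This project began as a public-opinion analysis task at my company, and I added a few small features later while entering some competitions. At the time I simply wanted to bring some simple machine learning and deep learning models together and practice my engineering skills. After talking with a few people online, I decided there was no need to build a general-purpose module (nobody would use it anyway, haha~). Since I've had some free time recently, I followed the principle of "the simpler, the better" and deleted the unnecessary, flashy parameters and functions, keeping only the preprocessing and the convolutional network.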

Loading the datasets: load_data

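Two datasets are provided: over 4,000 single-label e-commerce reviews and over 15,000 multi-label judicial charge records. The data is for academic research only; commercial distribution is prohibited.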

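  • The single-label e-commerce data (about 4,000 records) is a .csv file taken from real e-commerce reviews. It has two fields, 'evaluation' and 'label', holding the user review and its positive/negative label. Reading it with pandas is recommended; it loads as a DataFrame.
  • The multi-label judicial charge data (about 15,000 records) is a .json file taken from the 2018 'Fayan Cup' (法研杯) Legal AI Challenge (CAIL2018). It has two fields, 'fact' and 'accusation', holding the statement of facts and the charges; it loads as a list.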
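from TextClassification.load_data import load_data

# Single-label: wrap each label in a list so both tasks share one label format
data = load_data('single')
x = data['evaluation']
y = [[i] for i in data['label']]

# Multi-label: each record's 'accusation' already holds all of its charges
data = load_data('multiple')
x = [i['fact'] for i in data]
y = [i['accusation'] for i in data]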
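load_data hides the file handling. For readers who want to read the files themselves, here is a minimal sketch of the two layouts described above using only the standard library (the description recommends pandas for the CSV; `pandas.read_csv` would serve the same purpose). The sample rows, the label values, and the helper names `read_single` and `read_multiple` are hypothetical; they only mimic the field names 'evaluation', 'label', 'fact' and 'accusation'.

```python
import csv
import io
import json

# Hypothetical toy samples that mimic the described formats;
# the real files ship with the project and are read by load_data.
CSV_TEXT = "evaluation,label\nGreat quality,positive\nShipping too slow,negative\n"
JSON_TEXT = json.dumps([
    {"fact": "Defendant A stole property.", "accusation": ["theft"]},
    {"fact": "Defendant B robbed and injured a victim.",
     "accusation": ["robbery", "intentional injury"]},
])


def read_single(text):
    """Hypothetical helper. Single-label CSV: one 'evaluation' and one 'label' per row."""
    rows = list(csv.DictReader(io.StringIO(text)))
    x = [row["evaluation"] for row in rows]
    y = [[row["label"]] for row in rows]  # wrap each label in a list, as in the demo
    return x, y


def read_multiple(text):
    """Hypothetical helper. Multi-label JSON: a list of records with 'fact' and 'accusation'."""
    records = json.loads(text)
    x = [r["fact"] for r in records]
    y = [r["accusation"] for r in records]  # one list of charges per record
    return x, y


x, y = read_single(CSV_TEXT)
print(x, y)
```

Wrapping each single label in a list means both datasets end up as "a list of label lists", which is the shape the rest of the pipeline receives.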

Text preprocessing: DataPreprocess.py

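Preprocesses the raw text data: word segmentation, conversion to integer sequences, and padding or truncation to a uniform length. It is already wrapped into TextClassification.py.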

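preprocess = DataPreprocess()

# Process the texts: segment them, fit the tokenizer (num_words caps the
# vocabulary size), then convert to sequences of length sentence_len
texts_cut = preprocess.cut_texts(texts, word_len)
preprocess.train_tokenizer(texts_cut, num_words)
texts_seq = preprocess.text2seq(texts_cut, sentence_len)

# Get the labels: build the label set, then encode each text's labels
preprocess.creat_label_set(labels)
labels = preprocess.creat_labels(labels)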
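To make the steps above concrete, here is a self-contained sketch of what segmentation, encoding, padding and label encoding amount to. It is not DataPreprocess itself, and all function names are hypothetical. Assumptions: segmentation is character-level (the real cut_texts may use a word segmenter together with word_len), index 0 is reserved for padding, unknown tokens are dropped, and sequences are truncated and padded at the end (Keras' `pad_sequences`, for instance, defaults to the front). Labels become one-hot (single-label) or multi-hot (multi-label) vectors over a sorted label set.

```python
from collections import Counter

# Hypothetical helpers illustrating the preprocessing steps; not the project's API.


def char_cut(texts):
    # Assumption: character-level segmentation
    return [list(text) for text in texts]


def build_vocab(texts_cut, num_words):
    # Keep the num_words - 1 most frequent tokens; ids start at 1, 0 is padding
    counts = Counter(tok for text in texts_cut for tok in text)
    return {tok: i + 1 for i, (tok, _) in enumerate(counts.most_common(num_words - 1))}


def to_sequences(texts_cut, vocab, sentence_len):
    # Map tokens to ids (dropping unknown ones), truncate, then pad with 0 at the end
    seqs = []
    for text in texts_cut:
        ids = [vocab[tok] for tok in text if tok in vocab][:sentence_len]
        seqs.append(ids + [0] * (sentence_len - len(ids)))
    return seqs


def build_label_set(labels):
    # labels is a list of label lists, e.g. [["positive"], ["negative"]]
    return sorted({lab for row in labels for lab in row})


def encode_labels(labels, label_set):
    # One-hot for single-label rows, multi-hot for multi-label rows
    index = {lab: i for i, lab in enumerate(label_set)}
    encoded = []
    for row in labels:
        vec = [0] * len(label_set)
        for lab in row:
            vec[index[lab]] = 1
        encoded.append(vec)
    return encoded


texts_cut = char_cut(["abca", "bd"])
vocab = build_vocab(texts_cut, num_words=4)
print(to_sequences(texts_cut, vocab, sentence_len=3))  # [[1, 2, 3], [2, 0, 0]]
```

Fixing every sequence to the same length is what lets a batch of texts be fed to the convolutional network as a single matrix.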

Model training and prediction: TextClassification.py

Integrates preprocessing, network training and network prediction; see the two demo scripts for complete examples.

Methods:

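  • fit: trains on texts and labels (the demo below passes the sequences and labels returned by get_preprocess). Pass an existing model to continue training it; leave model=None to train from scratch;
  • predict: takes raw texts.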
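import pickle

import numpy as np
from TextClassification import TextClassification

# x_train, y_train, x_test, y_test: a train/test split of x and y from load_data (not shown)
# data_type: 'single' or 'multiple', matching the loaded dataset (the scoring below assumes 'single')
clf = TextClassification()
texts_seq, texts_labels = clf.get_preprocess(x_train, y_train,
                                             word_len=1,
                                             num_words=2000,
                                             sentence_len=50)
clf.fit(texts_seq=texts_seq,
        texts_labels=texts_labels,
        output_type=data_type,
        epochs=10,
        batch_size=64,
        model=None)

# Save the whole module, including the preprocessing and the neural network
with open('./%s.pkl' % data_type, 'wb') as f:
    pickle.dump(clf, f)

# Load the model that was just saved
with open('./%s.pkl' % data_type, 'rb') as f:
    clf = pickle.load(f)
y_predict = clf.predict(x_test)
# Single-label: take the highest-scoring class for each text
y_predict = [[clf.preprocess.label_set[i.argmax()]] for i in y_predict]
score = np.mean(np.array(y_predict) == np.array(y_test))
print(score)  # 0.9288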
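The scoring lines above only cover the single-label case, where argmax picks one class per text. In the multi-label charge data a text can carry several charges, so one common approach (an assumption here, not necessarily what the demo scripts do) is to keep every label whose predicted probability reaches a threshold, fall back to the top label when none does, and score with exact-match accuracy. The helper names and the 0.5 threshold are hypothetical.

```python
# Hypothetical helpers for decoding multi-label predictions; not the project's API.


def decode_multilabel(probs, label_set, threshold=0.5):
    # Keep every label whose probability reaches the threshold;
    # if none does, fall back to the single most likely label
    decoded = []
    for row in probs:
        picked = [label_set[i] for i, p in enumerate(row) if p >= threshold]
        if not picked:
            best = max(range(len(row)), key=lambda i: row[i])
            picked = [label_set[best]]
        decoded.append(picked)
    return decoded


def exact_match_accuracy(y_pred, y_true):
    # A text counts as correct only if its predicted label set equals the true set
    hits = sum(set(p) == set(t) for p, t in zip(y_pred, y_true))
    return hits / len(y_true)


probs = [[0.9, 0.2, 0.7], [0.1, 0.3, 0.2]]
label_set = ["a", "b", "c"]
y_pred = decode_multilabel(probs, label_set)
print(y_pred)  # [['a', 'c'], ['b']]
print(exact_match_accuracy(y_pred, [["c", "a"], ["c"]]))  # 0.5
```

Exact match is strict; per-label precision and recall are gentler alternatives when texts carry many charges.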

text-classification's People

Contributors

renjunxiang

text-classification's Issues

Predicted labels are severely inaccurate

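Hi there, fellow traveler~
Testing with the legal data from the demo, I found that almost all predicted labels (over 95%) are the same.
I checked the indices of the top three scores:
prediction.argsort()[-3:][::-1]
They look like this:
[139,34,55]
[139,34,55]
........
[139,55,34]

Basically identical every time.

The values inside prediction do differ, but the highest few always sit at the same positions.
Switching to my own data (with a richer set of labels) gives the same problem: every time it is the same few labels.
I can't figure out what's going wrong.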

Sorry, I misread.
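For reference, `prediction.argsort()[-3:][::-1]` lists the indices of the three highest scores, highest first. A quick way to check whether a model has collapsed onto a few labels, as suspected in this issue, is to measure how often the most common top-1 index occurs across all predictions; a share close to 1.0 means nearly every text gets the same label. A dependency-free sketch with hypothetical helper names:

```python
from collections import Counter

# Hypothetical diagnostic helpers; not part of the project.


def top_k(scores, k=3):
    # Indices of the k highest scores, highest first (like argsort()[-k:][::-1], up to ties)
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def top1_share(predictions):
    # Most frequent top-1 index and the fraction of predictions it accounts for
    counts = Counter(top_k(row, 1)[0] for row in predictions)
    index, n = counts.most_common(1)[0]
    return index, n / len(predictions)


print(top_k([0.1, 0.5, 0.3, 0.9]))  # [3, 1, 2]
print(top1_share([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3]]))  # (1, 0.6666666666666666)
```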

Why does the binary classification task use categorical_crossentropy, while the multi-class task uses binary_crossentropy?
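The usual reasoning behind this pairing: with two mutually exclusive classes, one-hot labels and a softmax output, categorical cross-entropy is mathematically the same as a single sigmoid output with binary cross-entropy, so either works. In the multi-label charge task (presumably what the issue calls multi-class), charges are not mutually exclusive, so each output gets its own sigmoid and is scored as an independent yes/no problem with binary cross-entropy. A plain-Python sketch of both losses, not the project's Keras code; the function names are hypothetical and only echo the Keras loss names:

```python
import math

# Hypothetical plain-Python versions of the two losses, for illustration only.


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def categorical_crossentropy(one_hot, probs):
    # Mutually exclusive classes: only the true class's probability contributes
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


def binary_crossentropy(multi_hot, probs):
    # Independent labels: each output is its own yes/no problem, averaged over labels
    return -sum(t * math.log(p) + (1 - t) * math.log(1 - p)
                for t, p in zip(multi_hot, probs)) / len(probs)


# A 2-class softmax over logits [z, 0] gives exactly sigmoid(z) for the first class
z = 1.3
p = softmax([z, 0.0])
print(abs(p[0] - sigmoid(z)) < 1e-12)  # True
```

Using softmax for multi-label data would force the charge probabilities to compete and sum to 1, which is why independent sigmoids are the standard choice there.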

Validation

You could plot an ROC curve and compute its AUC to evaluate the model's performance.
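A minimal sketch of that suggestion for a binary task such as the positive/negative reviews, using the predicted probability of the positive class as the score. `roc_points` returns the (false positive rate, true positive rate) pairs to plot, and `roc_auc` computes the area as the probability that a random positive outranks a random negative, with ties counting half. Both names are hypothetical; scikit-learn's `roc_curve` and `roc_auc_score` provide the same metrics. For the multi-label data, the same computation can be applied to each label separately (one-vs-rest).

```python
# Hypothetical ROC helpers (1 = positive class, 0 = negative class).


def roc_points(y_true, scores):
    # (FPR, TPR) at every distinct score threshold, from strictest to loosest
    n_pos = sum(1 for t in y_true if t == 1)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need at least one positive and one negative sample")
    points = [(0.0, 0.0)]
    for th in sorted(set(scores), reverse=True):
        tp = sum(1 for t, s in zip(y_true, scores) if s >= th and t == 1)
        fp = sum(1 for t, s in zip(y_true, scores) if s >= th and t == 0)
        points.append((fp / n_neg, tp / n_pos))
    return points


def roc_auc(y_true, scores):
    # Fraction of (positive, negative) pairs ranked correctly; ties count 0.5
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_points(y_true, scores))
print(roc_auc(y_true, scores))  # 0.75
```

The pairwise loop is quadratic and meant for illustration; the points from `roc_points` can be passed straight to a plotting library to draw the curve.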
