Giter Site home page Giter Site logo

brokenwind / bertsimilarity Goto Github PK

View Code? Open in Web Editor NEW
483.0 7.0 69.0 2.83 MB

Computing similarity of two sentences with google's BERT algorithm。利用Bert计算句子相似度。语义相似度计算。文本相似度计算。

Python 99.14% Shell 0.86%
bert semantic nlp similarity python tensorflow

bertsimilarity's Introduction

BertSimilarity

基于Google的BERT模型来进行语义相似度计算。代码基于tensorflow 1。

1. 基本原理

简单来说就是将需要计算的相似的两个句子先拼接在一起,然后通过Bert模型获取获取整体的编码信息,接着通过全连接层将维,输出相似和不相似的概率。

1.1 模型结构

模型结构如图所示:

1.1.1 数据预处理

本文使用Bert模型计算相似度前,首先要对输入数据进行预处理,例如当要处理的文本是:

如何得知关闭借呗   想永久关闭借呗

首先要将文本按token化,切分成字数组:

[如 何 得 知 关 闭 借 呗]

[想 永 久 关 闭 借 呗]

然后将两个切分后的句子,按照如下的方式拼接:

[CLS] 如 何 得 知 关 闭 借 呗 [SEP] 想 永 久 关 闭 借 呗 [SEP]

拼接后句子中的[CLS]是一个特殊分隔符表示每个样本的开头,[SEP]是样本中每个句子的结束标记符。拼接后的样本长度是len(A)+len(B)+3,因为加上了[CLS]和[SEP]特殊标识符的,3指的是1个[CLS]和2个[SEP]。在模型中,会指定最大样本长度LEN,本次实验中指定的是30),也就处理后的样本长度和不能超过LEN,即

len(A) + len(B) + 3 <= LEN

超出最大长度的部分会被截取。然后继续对输入数据进行处理,将内容数字化。将token数组转化为token的索引数组,所有的样本会被转化为固定长度(LEN)的索引数组。[CLS]的索引是101,[SEP]的索引是102,不够最大样本长度的部分补0。然后对输入进行mask操作,样本有效部分是1,无效部分是0,0表示在训练时不会被关注。再对两个句子进行分割,分割操作是通过一个数组来标记,属于第一个句子标记为0,属于第二个句子的标记为1,不足部分也填0。

处理后结果如下图所示:

input_ids,input_mask,segment_ids都是Bert预训练模型需要的输入。

1.1.2 Bert编码阶段

Bert预训练简介

Bert模型结构如图所示:

Bert没有使用word2vec进行词嵌入,而是直接在原始语料库上训练,并且在训练过程中同时进行两个预测任务,分别是遮蔽语言模型MLM(masked language model)任务和根据上一个句子预测下一个句子的任务。Bert使用MLM来解决单向局限,其随机地从输入中遮蔽一些词块,然后通过上下文语境来预测被遮蔽的词块。从Bert的双向Transformer结构中我们可以发现MLM任务并不只是从左到右进行预测,而是它融合了遮蔽词块左右两边的句子语境,充分利用两个方向的注意力机制。同时训练数据不需要人工标注,而是在训练之前随机产生训练数据,减少了由人工标注所带来的误差,只需要通过合适的方法,对现有语料中的句子进行随机的遮掩即可得到可以用来训练的语料。很多自然语言处理任务比如问答系统和自然语言推断任务都需要对两个句子之间关系进行理解,而这种理解并不能通过语言模型直接获得。在预训练任务中加入预测下一个句子的训练可以达到理解句子关系的目的,具体做法是随机替换一些句子,然后利用上一句进行预测下一句的真伪。

利用Bert对句子对进行编码

使用已经训练好的Bert模型对句子对进行编码时,因为Bert中双向Attention机制,两个句子会相互Attention,也就是通过训练会学到两个句子的相似程度。

1.1.3 微调阶段

将句子对通过Bert预训练模型计算之后,获取两个句子的的最终编码。并对其进行0.1的Dropout,主要是为了防止模型过拟合,Dropout后的结果后边接一层全连接层,全连接层的输出维度为2,然后通过softmax计算,得到相似和不相似的概率。

2. 使用方式

2.1 数据准备

在使用前需要先下载Google预训练的Bert模型。

下载地址 https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

下载完后,解压,然后将chinese_L-12_H-768_A-12文件夹复制到 data/pretrained_model/目录中

2.2 运行说明

为了方便执行,已经编写了Shell脚本,在linux机器上直接运行start.sh,根据提示输入对应的mode(train/eval/infer)即可。

2.2.1 训练阶段

cd BertSimilarity/
sh start.sh
# 然后根据提示输入 train

执行效果:

执行日志查看:

tail -f logs/similarity_train.log

如果没有机器进行训练,我这里已经训练好了一份参数文件:

点击此处获取参数文件,提取码: fud8

解压之后,model_output复制到项目根目录的data目录下边,然后就可以进行评价和测试

2.2.2 评价阶段

cd BertSimilarity/
sh start.sh
# 然后根据提示输入 eval

执行效果:

执行日志查看:

tail -f logs/similarity_eval.log

评价结果:

2.2.3 测试阶段

cd BertSimilarity/
sh start.sh
# 然后根据提示输入 infer,待程序启动好之后,就可以分别输入两个句子,测试其相似程度

执行效果:

bertsimilarity's People

Contributors

brokenwind avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar

bertsimilarity's Issues

similarity.py注释

您好,请问可以在百忙之中将similarity.py文件中的各个步骤尽可能详细的加一些注释吗,都知道是干什么的,数据维度什么的,有助于学习与理解。万分感谢 抱拳

训练卡在Saving checkpoints for 0 ,请问什么原因?

INFO:tensorflow:Saving checkpoints for 0 into ...../model.ckpt.
I1107 14:10:38.075445 140053106304832 basic_session_run_hooks.py:606] Saving checkpoints for 0 into ....../model.ckpt.

到这就停了。
Top命令,也没找到python的进程。
4核cpu. ubuntu18.04.
有什么解决办法么?

bert词向量

您好,我看项目中没有代码提到chinese_L-12_H-768_A-12词向量模型,是没有使用吗

相似度很低

用这个提供的训练好的模型,同样的测试案例,相似度很低

similarity

下载作者训练好的模型做预测,同样的测试案例,相似度却很低

运行报错

使用tf 2.5运行报错
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., tf.Variable(lambda : tf.truncated_normal([10, 40]))) when building functions. Please file a feature request if this restriction inconveniences you.

在预测两个短文本相似度时,每次输入一对短文本,预测结果非常准确。但是存在一个性能问题:预测一对短文本需要花费大约5-8秒,时间太久。然后请问该如何优化代码呢?

调用代码如下
`sim.set_mode(tf.estimator.ModeKeys.PREDICT)

predict_start_time = datetime.now()

predict = sim.predict(text_1, text_2)

predict_end_time = datetime.now()

print("预测predict花费时间:", (predict_end_time - predict_start_time).total_seconds())

score = predict[0][1]

response["score"] = str(score)
response["words"] = request_body
print('TextSimilarity time used: {} sec'.format((datetime.now() - start).total_seconds()))`

实际结果如下:
INFO:tensorflow:tokens: [CLS] 借 呗 逾 期 短 信 通 知 [SEP] 如 何 购 买 花 呗 短 信 通 知 [SEP]
INFO:tensorflow:input_ids: 101 955 1446 6874 3309 4764 928 6858 4761 102 1963 862 6579 743 5709 1446 4764 928 6858 4761 102 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:label: 0 (id = 0)
预测predict花费时间: 0:00:07.576646
TextSimilarity time used: 8.100247 sec

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.