Giter Site home page Giter Site logo

ymcui / chinese-electra Goto Github PK

View Code? Open in Web Editor NEW
1.4K 26.0 171.0 431 KB

Pre-trained Chinese ELECTRA(中文ELECTRA预训练模型)

Home Page: http://electra.hfl-rc.com

License: Apache License 2.0

Python 100.00%
nlp bert electra pre-trained-model chinese chinese-electra language-model pytorch tensorflow

chinese-electra's Introduction

中文说明 | English



GitHub

谷歌与斯坦福大学共同研发的最新预训练模型ELECTRA因其小巧的模型体积以及良好的模型性能受到了广泛关注。 为了进一步促进中文预训练模型技术的研究与发展,哈工大讯飞联合实验室基于官方ELECTRA训练代码以及大规模的中文数据训练出中文ELECTRA预训练模型供大家下载使用。 其中ELECTRA-small模型可与BERT-base甚至其他同等规模的模型相媲美,而参数量仅为BERT-base的1/10。

本项目基于谷歌&斯坦福大学官方的ELECTRA:https://github.com/google-research/electra


中文LERT | 中英文PERT | 中文MacBERT | 中文ELECTRA | 中文XLNet | 中文BERT | 知识蒸馏工具TextBrewer | 模型裁剪工具TextPruner

查看更多哈工大讯飞联合实验室(HFL)发布的资源:https://github.com/ymcui/HFL-Anthology

新闻

2023/3/28 开源了中文LLaMA&Alpaca大模型,可快速在PC上部署体验,查看:https://github.com/ymcui/Chinese-LLaMA-Alpaca

2022/10/29 我们提出了一种融合语言学信息的预训练模型LERT。查看:https://github.com/ymcui/LERT

2022/3/30 我们开源了一种新预训练模型PERT。查看:https://github.com/ymcui/PERT

2021/12/17 哈工大讯飞联合实验室推出模型裁剪工具包TextPruner。查看:https://github.com/airaria/TextPruner

2021/10/24 哈工大讯飞联合实验室发布面向少数民族语言的预训练模型CINO。查看:https://github.com/ymcui/Chinese-Minority-PLM

2021/7/21 由哈工大SCIR多位学者撰写的《自然语言处理:基于预训练模型的方法》已出版,欢迎大家选购。

2020/12/13 基于大规模法律文书数据,我们训练了面向司法领域的中文ELECTRA系列模型,查看模型下载司法任务效果

点击这里查看历史新闻 2020/10/22 ELECTRA-180g已发布,增加了CommonCrawl的高质量数据,查看[模型下载](#模型下载)。

2020/9/15 我们的论文"Revisiting Pre-Trained Models for Chinese Natural Language Processing"Findings of EMNLP录用为长文。

2020/8/27 哈工大讯飞联合实验室在通用自然语言理解评测GLUE中荣登榜首,查看GLUE榜单新闻

2020/5/29 Chinese ELECTRA-large/small-ex已发布,请查看模型下载,目前只提供Google Drive下载地址,敬请谅解。

2020/4/7 PyTorch用户可通过🤗Transformers加载模型,查看快速加载

2020/3/31 本目录发布的模型已接入飞桨PaddleHub,查看快速加载

2020/3/25 Chinese ELECTRA-small/base已发布,请查看模型下载

内容导引

章节 描述
简介 介绍ELECTRA基本原理
模型下载 中文ELECTRA预训练模型下载
快速加载 介绍了如何使用🤗TransformersPaddleHub快速加载模型
基线系统效果 中文基线系统效果:阅读理解、文本分类等
使用方法 模型的详细使用方法
FAQ 常见问题答疑
引用 本目录的技术报告

简介

ELECTRA提出了一套新的预训练框架,其中包括两个部分:GeneratorDiscriminator

  • Generator: 一个小的MLM,在[MASK]的位置预测原来的词。Generator将用来把输入文本做部分词的替换。
  • Discriminator: 判断输入句子中的每个词是否被替换,即使用Replaced Token Detection (RTD)预训练任务,取代了BERT原始的Masked Language Model (MLM)。需要注意的是这里并没有使用Next Sentence Prediction (NSP)任务。

在预训练阶段结束之后,我们只使用Discriminator作为下游任务精调的基模型。

更详细的内容请查阅ELECTRA论文:ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators

模型下载

本目录中包含以下模型,目前仅提供TensorFlow版本权重。

  • ELECTRA-large, Chinese: 24-layer, 1024-hidden, 16-heads, 324M parameters
  • ELECTRA-base, Chinese: 12-layer, 768-hidden, 12-heads, 102M parameters
  • ELECTRA-small-ex, Chinese: 24-layer, 256-hidden, 4-heads, 25M parameters
  • ELECTRA-small, Chinese: 12-layer, 256-hidden, 4-heads, 12M parameters

大语料版(新版,180G数据)

模型简称 Google下载 百度网盘下载 压缩包大小
ELECTRA-180g-large, Chinese TensorFlow TensorFlow(密码2v5r) 1G
ELECTRA-180g-base, Chinese TensorFlow TensorFlow(密码3vg1) 383M
ELECTRA-180g-small-ex, Chinese TensorFlow TensorFlow(密码93n8) 92M
ELECTRA-180g-small, Chinese TensorFlow TensorFlow(密码k9iu) 46M

基础版(原版,20G数据)

模型简称 Google下载 百度网盘下载 压缩包大小
ELECTRA-large, Chinese TensorFlow TensorFlow(密码1e14) 1G
ELECTRA-base, Chinese TensorFlow TensorFlow(密码f32j) 383M
ELECTRA-small-ex, Chinese TensorFlow TensorFlow(密码gfb1) 92M
ELECTRA-small, Chinese TensorFlow TensorFlow(密码1r4r) 46M

司法领域版(new)

模型简称 Google下载 百度网盘下载 压缩包大小
legal-ELECTRA-large, Chinese TensorFlow TensorFlow(密码q4gv) 1G
legal-ELECTRA-base, Chinese TensorFlow TensorFlow(密码8gcv) 383M
legal-ELECTRA-small, Chinese TensorFlow TensorFlow(密码kmrj) 46M

PyTorch/TF2版本

如需PyTorch版本,请自行通过🤗Transformers提供的转换脚本convert_electra_original_tf_checkpoint_to_pytorch.py进行转换。如需配置文件可进入到本目录下的config文件夹中查找。

python transformers/src/transformers/convert_electra_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path ./path-to-large-model/ \
--config_file ./path-to-large-model/discriminator.json \
--pytorch_dump_path ./path-to-output/model.bin \
--discriminator_or_generator discriminator

或者通过huggingface官网直接下载PyTorch版权重:https://huggingface.co/hfl

方法:点击任意需要下载的model → 拉到最下方点击"List all files in model" → 在弹出的小框中下载bin和json文件。

使用须知

**大陆境内建议使用百度网盘下载点,境外用户建议使用谷歌下载点。 以TensorFlow版ELECTRA-small, Chinese为例,下载完毕后对zip文件进行解压得到如下文件。

chinese_electra_small_L-12_H-256_A-4.zip
    |- electra_small.data-00000-of-00001    # 模型权重
    |- electra_small.meta                   # 模型meta信息
    |- electra_small.index                  # 模型index信息
    |- vocab.txt                            # 词表
    |- discriminator.json                   # 配置文件:discriminator(若没有可从本repo中的config目录获取)
    |- generator.json                       # 配置文件:generator(若没有可从本repo中的config目录获取)

训练细节

我们采用了大规模中文维基以及通用文本训练了ELECTRA模型,总token数达到5.4B,与RoBERTa-wwm-ext系列模型一致。词表方面沿用了谷歌原版BERT的WordPiece词表,包含21,128个token。其他细节和超参数如下(未提及的参数保持默认):

  • ELECTRA-large: 24层,隐层1024,16个注意力头,学习率1e-4,batch96,最大长度512,训练2M步
  • ELECTRA-base: 12层,隐层768,12个注意力头,学习率2e-4,batch256,最大长度512,训练1M步
  • ELECTRA-small-ex: 24层,隐层256,4个注意力头,学习率5e-4,batch384,最大长度512,训练2M步
  • ELECTRA-small: 12层,隐层256,4个注意力头,学习率5e-4,batch1024,最大长度512,训练1M步

快速加载

使用Huggingface-Transformers

Huggingface-Transformers 2.8.0版本已正式支持ELECTRA模型,可通过如下命令调用。

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME) 

其中MODEL_NAME对应列表如下:

模型名 组件 MODEL_NAME
ELECTRA-180g-large, Chinese discriminator hfl/chinese-electra-180g-large-discriminator
ELECTRA-180g-large, Chinese generator hfl/chinese-electra-180g-large-generator
ELECTRA-180g-base, Chinese discriminator hfl/chinese-electra-180g-base-discriminator
ELECTRA-180g-base, Chinese generator hfl/chinese-electra-180g-base-generator
ELECTRA-180g-small-ex, Chinese discriminator hfl/chinese-electra-180g-small-ex-discriminator
ELECTRA-180g-small-ex, Chinese generator hfl/chinese-electra-180g-small-ex-generator
ELECTRA-180g-small, Chinese discriminator hfl/chinese-electra-180g-small-discriminator
ELECTRA-180g-small, Chinese generator hfl/chinese-electra-180g-small-generator
ELECTRA-large, Chinese discriminator hfl/chinese-electra-large-discriminator
ELECTRA-large, Chinese generator hfl/chinese-electra-large-generator
ELECTRA-base, Chinese discriminator hfl/chinese-electra-base-discriminator
ELECTRA-base, Chinese generator hfl/chinese-electra-base-generator
ELECTRA-small-ex, Chinese discriminator hfl/chinese-electra-small-ex-discriminator
ELECTRA-small-ex, Chinese generator hfl/chinese-electra-small-ex-generator
ELECTRA-small, Chinese discriminator hfl/chinese-electra-small-discriminator
ELECTRA-small, Chinese generator hfl/chinese-electra-small-generator

司法领域版本:

模型名 组件 MODEL_NAME
legal-ELECTRA-large, Chinese discriminator hfl/chinese-legal-electra-large-discriminator
legal-ELECTRA-large, Chinese generator hfl/chinese-legal-electra-large-generator
legal-ELECTRA-base, Chinese discriminator hfl/chinese-legal-electra-base-discriminator
legal-ELECTRA-base, Chinese generator hfl/chinese-legal-electra-base-generator
legal-ELECTRA-small, Chinese discriminator hfl/chinese-legal-electra-small-discriminator
legal-ELECTRA-small, Chinese generator hfl/chinese-legal-electra-small-generator

使用PaddleHub

依托PaddleHub,我们只需一行代码即可完成模型下载安装,十余行代码即可完成文本分类、序列标注、阅读理解等任务。

import paddlehub as hub
module = hub.Module(name=MODULE_NAME)

其中MODULE_NAME对应列表如下:

模型名 MODULE_NAME
ELECTRA-base, Chinese chinese-electra-base
ELECTRA-small, Chinese chinese-electra-small

基线系统效果

我们将ELECTRA-small/baseBERT-baseBERT-wwmBERT-wwm-extRoBERTa-wwm-extRBT3进行了效果对比,包括以下六个任务:

对于ELECTRA-small/base模型,我们使用原论文默认的3e-41e-4的学习率。 需要注意的是,我们没有针对任何任务进行参数精调,所以通过调整学习率等超参数可能获得进一步性能提升。 为了保证结果的可靠性,对于同一模型,我们使用不同随机种子训练10遍,汇报模型性能的最大值和平均值(括号内为平均值)。

简体中文阅读理解:CMRC 2018

CMRC 2018数据集是哈工大讯飞联合实验室发布的中文机器阅读理解数据。 根据给定问题,系统需要从篇章中抽取出片段作为答案,形式与SQuAD相同。 评价指标为:EM / F1

模型 开发集 测试集 挑战集 参数量
BERT-base 65.5 (64.4) / 84.5 (84.0) 70.0 (68.7) / 87.0 (86.3) 18.6 (17.0) / 43.3 (41.3) 102M
BERT-wwm 66.3 (65.0) / 85.6 (84.7) 70.5 (69.1) / 87.4 (86.7) 21.0 (19.3) / 47.0 (43.9) 102M
BERT-wwm-ext 67.1 (65.6) / 85.7 (85.0) 71.4 (70.0) / 87.7 (87.0) 24.0 (20.0) / 47.3 (44.6) 102M
RoBERTa-wwm-ext 67.4 (66.5) / 87.2 (86.5) 72.6 (71.4) / 89.4 (88.8) 26.2 (24.6) / 51.0 (49.1) 102M
RBT3 57.0 / 79.0 62.2 / 81.8 14.7 / 36.2 38M
ELECTRA-small 63.4 (62.9) / 80.8 (80.2) 67.8 (67.4) / 83.4 (83.0) 16.3 (15.4) / 37.2 (35.8) 12M
ELECTRA-180g-small 63.8 / 82.7 68.5 / 85.2 15.1 / 35.8 12M
ELECTRA-small-ex 66.4 / 82.2 71.3 / 85.3 18.1 / 38.3 25M
ELECTRA-180g-small-ex 68.1 / 85.1 71.8 / 87.2 20.6 / 41.7 25M
ELECTRA-base 68.4 (68.0) / 84.8 (84.6) 73.1 (72.7) / 87.1 (86.9) 22.6 (21.7) / 45.0 (43.8) 102M
ELECTRA-180g-base 69.3 / 87.0 73.1 / 88.6 24.0 / 48.6 102M
ELECTRA-large 69.1 / 85.2 73.9 / 87.1 23.0 / 44.2 324M
ELECTRA-180g-large 68.5 / 86.2 73.5 / 88.5 21.8 / 42.9 324M

繁体中文阅读理解:DRCD

DRCD数据集由****台达研究院发布,其形式与SQuAD相同,是基于繁体中文的抽取式阅读理解数据集。 评价指标为:EM / F1

模型 开发集 测试集 参数量
BERT-base 83.1 (82.7) / 89.9 (89.6) 82.2 (81.6) / 89.2 (88.8) 102M
BERT-wwm 84.3 (83.4) / 90.5 (90.2) 82.8 (81.8) / 89.7 (89.0) 102M
BERT-wwm-ext 85.0 (84.5) / 91.2 (90.9) 83.6 (83.0) / 90.4 (89.9) 102M
RoBERTa-wwm-ext 86.6 (85.9) / 92.5 (92.2) 85.6 (85.2) / 92.0 (91.7) 102M
RBT3 76.3 / 84.9 75.0 / 83.9 38M
ELECTRA-small 79.8 (79.4) / 86.7 (86.4) 79.0 (78.5) / 85.8 (85.6) 12M
ELECTRA-180g-small 83.5 / 89.2 82.9 / 88.7 12M
ELECTRA-small-ex 84.0 / 89.5 83.3 / 89.1 25M
ELECTRA-180g-small-ex 87.3 / 92.3 86.5 / 91.3 25M
ELECTRA-base 87.5 (87.0) / 92.5 (92.3) 86.9 (86.6) / 91.8 (91.7) 102M
ELECTRA-180g-base 89.6 / 94.2 88.9 / 93.7 102M
ELECTRA-large 88.8 / 93.3 88.8 / 93.6 324M
ELECTRA-180g-large 90.1 / 94.8 90.5 / 94.7 324M

自然语言推断:XNLI

在自然语言推断任务中,我们采用了XNLI数据,需要将文本分成三个类别:entailmentneutralcontradictory。 评价指标为:Accuracy

模型 开发集 测试集 参数量
BERT-base 77.8 (77.4) 77.8 (77.5) 102M
BERT-wwm 79.0 (78.4) 78.2 (78.0) 102M
BERT-wwm-ext 79.4 (78.6) 78.7 (78.3) 102M
RoBERTa-wwm-ext 80.0 (79.2) 78.8 (78.3) 102M
RBT3 72.2 72.3 38M
ELECTRA-small 73.3 (72.5) 73.1 (72.6) 12M
ELECTRA-180g-small 74.6 74.6 12M
ELECTRA-small-ex 75.4 75.8 25M
ELECTRA-180g-small-ex 76.5 76.6 25M
ELECTRA-base 77.9 (77.0) 78.4 (77.8) 102M
ELECTRA-180g-base 79.6 79.5 102M
ELECTRA-large 81.5 81.0 324M
ELECTRA-180g-large 81.2 80.4 324M

情感分析:ChnSentiCorp

在情感分析任务中,二分类的情感分类数据集ChnSentiCorp。 评价指标为:Accuracy

模型 开发集 测试集 参数量
BERT-base 94.7 (94.3) 95.0 (94.7) 102M
BERT-wwm 95.1 (94.5) 95.4 (95.0) 102M
BERT-wwm-ext 95.4 (94.6) 95.3 (94.7) 102M
RoBERTa-wwm-ext 95.0 (94.6) 95.6 (94.8) 102M
RBT3 92.8 92.8 38M
ELECTRA-small 92.8 (92.5) 94.3 (93.5) 12M
ELECTRA-180g-small 94.1 93.6 12M
ELECTRA-small-ex 92.6 93.6 25M
ELECTRA-180g-small-ex 92.8 93.4 25M
ELECTRA-base 93.8 (93.0) 94.5 (93.5) 102M
ELECTRA-180g-base 94.3 94.8 102M
ELECTRA-large 95.2 95.3 324M
ELECTRA-180g-large 94.8 95.2 324M

句对分类:LCQMC

以下两个数据集均需要将一个句对进行分类,判断两个句子的语义是否相同(二分类任务)。

LCQMC由哈工大深圳研究生院智能计算研究中心发布。 评价指标为:Accuracy

模型 开发集 测试集 参数量
BERT 89.4 (88.4) 86.9 (86.4) 102M
BERT-wwm 89.4 (89.2) 87.0 (86.8) 102M
BERT-wwm-ext 89.6 (89.2) 87.1 (86.6) 102M
RoBERTa-wwm-ext 89.0 (88.7) 86.4 (86.1) 102M
RBT3 85.3 85.1 38M
ELECTRA-small 86.7 (86.3) 85.9 (85.6) 12M
ELECTRA-180g-small 86.6 85.8 12M
ELECTRA-small-ex 87.5 86.0 25M
ELECTRA-180g-small-ex 87.6 86.3 25M
ELECTRA-base 90.2 (89.8) 87.6 (87.3) 102M
ELECTRA-180g-base 90.2 87.1 102M
ELECTRA-large 90.7 87.3 324M
ELECTRA-180g-large 90.3 87.3 324M

句对分类:BQ Corpus

BQ Corpus由哈工大深圳研究生院智能计算研究中心发布,是面向银行领域的数据集。 评价指标为:Accuracy

模型 开发集 测试集 参数量
BERT 86.0 (85.5) 84.8 (84.6) 102M
BERT-wwm 86.1 (85.6) 85.2 (84.9) 102M
BERT-wwm-ext 86.4 (85.5) 85.3 (84.8) 102M
RoBERTa-wwm-ext 86.0 (85.4) 85.0 (84.6) 102M
RBT3 84.1 83.3 38M
ELECTRA-small 83.5 (83.0) 82.0 (81.7) 12M
ELECTRA-180g-small 83.3 82.1 12M
ELECTRA-small-ex 84.0 82.6 25M
ELECTRA-180g-small-ex 84.6 83.4 25M
ELECTRA-base 84.8 (84.7) 84.5 (84.0) 102M
ELECTRA-180g-base 85.8 84.5 102M
ELECTRA-large 86.7 85.1 324M
ELECTRA-180g-large 86.4 85.4 324M

司法任务效果

我们使用CAIL 2018司法评测的罪名预测数据对司法ELECTRA进行了测试。small/base/large学习率分别为:5e-4/3e-4/1e-4。 评价指标为:Accuracy

模型 开发集 测试集 参数量
ELECTRA-small 78.84 76.35 12M
legal-ELECTRA-small 79.60 77.03 12M
ELECTRA-base 80.94 78.41 102M
legal-ELECTRA-base 81.71 79.17 102M
ELECTRA-large 81.53 78.97 324M
legal-ELECTRA-large 82.60 79.89 324M

使用方法

用户可以基于已发布的上述中文ELECTRA预训练模型进行下游任务精调。 在这里我们只介绍最基本的用法,更详细的用法请参考ELECTRA官方介绍

本例中,我们使用ELECTRA-small模型在CMRC 2018任务上进行精调,相关步骤如下。假设,

  • data-dir:工作根目录,可按实际情况设置。
  • model-name:模型名称,本例中为electra-small
  • task-name:任务名称,本例中为cmrc2018。本目录中的代码已适配了以上六个中文任务,task-name分别为cmrc2018drcdxnlichnsenticorplcqmcbqcorpus

第一步:下载预训练模型并解压

模型下载章节中,下载ELECTRA-small模型,并解压至${data-dir}/models/${model-name}。 该目录下应包含electra_model.*vocab.txtcheckpoint,共计5个文件。

第二步:准备任务数据

下载CMRC 2018训练集和开发集,并重命名为train.jsondev.json。 将两个文件放到${data-dir}/finetuning_data/${task-name}

第三步:运行训练命令

python run_finetuning.py \
    --data-dir ${data-dir} \
    --model-name ${model-name} \
    --hparams params_cmrc2018.json

其中data-dirmodel-name在上面已经介绍。hparams是一个JSON词典,在本例中的params_cmrc2018.json包含了精调相关超参数,例如:

{
    "task_names": ["cmrc2018"],
    "max_seq_length": 512,
    "vocab_size": 21128,
    "model_size": "small",
    "do_train": true,
    "do_eval": true,
    "write_test_outputs": true,
    "num_train_epochs": 2,
    "learning_rate": 3e-4,
    "train_batch_size": 32,
    "eval_batch_size": 32,
}

在上述JSON文件中,我们只列举了最重要的一些参数,完整参数列表请查阅configure_finetuning.py

运行完毕后,

  1. 对于阅读理解任务,生成的预测JSON数据cmrc2018_dev_preds.json保存在${data-dir}/results/${task-name}_qa/。可以调用外部评测脚本来得到最终评测结果,例如:python cmrc2018_drcd_evaluate.py dev.json cmrc2018_dev_preds.json
  2. 对于分类任务,相关accuracy信息会直接打印在屏幕,例如:xnli: accuracy: 72.5 - loss: 0.67

FAQ

Q: 在下游任务精调的时候ELECTRA模型的学习率怎么设置?
A: 我们建议使用原论文使用的学习率作为初始基线(small是3e-4,base是1e-4)然后适当增减学习率进行调试。 需要注意的是,相比BERT、RoBERTa一类的模型来说ELECTRA的学习率要相对大一些。

Q: 有没有PyTorch版权重?
A: 有,模型下载

Q: 预训练用的数据能共享一下吗?
A: 很遗憾,不可以。

Q: 未来计划?
A: 敬请关注。

引用

如果本目录中的内容对你的研究工作有所帮助,欢迎在论文中引用下述论文。

@journal{cui-etal-2021-pretrain,
  title={Pre-Training with Whole Word Masking for Chinese BERT},
  author={Cui, Yiming and Che, Wanxiang and Liu, Ting and Qin, Bing and Yang, Ziqing},
  journal={IEEE Transactions on Audio, Speech and Language Processing},
  year={2021},
  url={https://ieeexplore.ieee.org/document/9599397},
  doi={10.1109/TASLP.2021.3124365},
 }
@inproceedings{cui-etal-2020-revisiting,
    title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
    author = "Cui, Yiming  and
      Che, Wanxiang  and
      Liu, Ting  and
      Qin, Bing  and
      Wang, Shijin  and
      Hu, Guoping",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
    month = nov,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
    pages = "657--668",
}

关注我们

欢迎关注哈工大讯飞联合实验室官方微信公众号,了解最新的技术动态。

qrcode.png

问题反馈

Before you submit an issue:

  • You are advised to read FAQ first before you submit an issue.
  • Repetitive and irrelevant issues will be ignored and closed by [stable-bot](stale · GitHub Marketplace). Thank you for your understanding and support.
  • We cannot acommodate EVERY request, and thus please bare in mind that there is no guarantee that your request will be met.
  • Always be polite when you submit an issue.

chinese-electra's People

Contributors

cclauss avatar kinghuin avatar ymcui avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

chinese-electra's Issues

模型加载兼容问题

我感觉electra的modeling即是bert的modeling,预训练的优化在于优化目标;我想问下electra的run_finetuning.py能不能直接加载bert,bert-wwm,robertra等等的模型;目前加载似乎不太合适,主要修改那个代码块?

预训练数据量

大神,我想问下,electra-small, electra-large预训练的训练数据大概是多少?

想加载已有模型进一步预训练时缺少adam应该怎么办呢?

谷歌原repo提到对已有模型经一部预训练的方法是将路径改到已有模型上继续运行run_pretraining.py,即
Setting the model-name to point to a downloaded model (e.g., --model-name electra_small if you downloaded weights to $DATA_DIR/electra_small).
但使用中文electra时似乎因为去掉了adam_m而无法进一步预训练
报错信息为:Key discriminator_predictions/dense/bias/adam_m not found in checkpoint

运行TF版本的模型,提示python停止工作

没有GPU,每次运行时python就会提示停止工作
D:\paper\electra>python run_finetuning.py --data-dir \train --model-name electra
_small --hparams lcqmc.json
2020-06-28 19:51:11.486608: W tensorflow/stream_executor/platform/default/dso_lo
ader.cc:55] Could not load dynamic library 'cudart64_100.dll'; dlerror: cudart64
_100.dll not found
2020-06-28 19:51:11.496134: I tensorflow/stream_executor/cuda/cudart_stub.cc:29]
Ignore above cudart dlerror if you do not have a GPU set up on your machine.
WARNING:tensorflow:From D:\paper\electra\model\optimization.py:70: The name tf.t
rain.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

ELECTRA-small, Chinese 权重丢了5个?

提供开源模型的老师,您好,我通过讯飞云下载的模型, 解析ELECTRA-small模型发现有五个权重找不到。
error electra/encoder/layer_9/intermediate/dense/kernel
error electra/encoder/layer_9/output/LayerNorm/beta
error electra/encoder/layer_9/output/LayerNorm/gamma
error electra/encoder/layer_9/output/dense/bias
error electra/encoder/layer_9/output/dense/kernel

请检查是否存在该问题,十分感谢。

换了个源下载是ok的,抱歉

最好还是配上对应的json

解压后发现没有json文件,而不少框架都是根据json文件来读取模型基本结构的,建议还是配上。

比如small版

{
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "max_position_embeddings": 512,
  "num_attention_heads": 4,
  "num_hidden_layers": 12,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128,
  "embedding_size": 128
}

还有,不知道为啥ckpt的命名不加上ckpt...

最后,最新版bert4keras(0.6.4)已经能加载electra了,只需要在build_transformer_model里边传入model='electra',欢迎用bert4keras调用哈哈~

我在finetune CMRC 2018时发现预处理得到的数据不及预期

我看了finetune/qa/qa_tasks.py中的代码,似乎与原版的electra一模一样。
官方的代码针对squad2.0的预处理是英文的。英文的预处理可以适配中文的qa数据吗?
我们在运行该仓库的代码finetune CMRC数据时。得到的example大致如下

{'task_name': 'cmrc2018', 'eid': 0, 'qas_id': 'TRAIN_186_QUERY_0', 'qid': None, 'question_text': '范廷颂是什么时候被任为主教的?', 'doc_tokens': ['范廷颂枢机(,),圣名保禄·若瑟(),是越南罗马天主教枢机。1963年被任为主教;1990年被擢升为天主教河内总教区宗座署理;1994年被擢升为总主教,同年年底被擢升为枢机;2009年2月离世。范廷颂于1919年6月15日在越南宁平省天主教发艳教区出生;童年时接受良好教育后,被一位越南神父带到河内继续其学业。范廷颂于1940年在河内大修道院完成神学学业。范廷颂于1949年6月6日在河内的主教座堂晋铎;及后被派到圣女小德兰孤儿院服务。1950年代,范廷颂在河内堂区创建移民接待中心以收容到河内避战的难民。1954年,法越战争结束,越南**共和国建都河内,当时很多天主教神职人员逃至越南的南方,但范廷颂仍然留在河内。翌年管理圣若望小修院;惟在1960年因捍卫修院的自由、自治及拒绝政府在修院设政治课的要求而被捕。1963年4月5日,教宗任命范廷颂为天主教北宁教区主教,同年8月15日就任;其牧铭为「我信天主的爱」。由于范廷颂被越南政府软禁差不多30年,因此他无法到所属堂区进行牧灵工作而专注研读等工作。范廷颂除了面对战争、贫困、被当局**天主教会等问题外,也秘密恢复修院、创建女修会团体等。1990年,教宗若望保禄二世在同年6月18日擢升范廷颂为天主教河内总教区宗座署理以填补该教区总主教的空缺。1994年3月23日,范廷颂被教宗若望保禄二世擢升为天主教河内总教区总主教并兼天主教谅山教区宗座署理;同年11月26日,若望保禄二世擢升范廷颂为枢机。范廷颂在1995年至2001年期间出任天主教越南主教团主席。2003年4月26日,教宗若望保禄二世任命天主教谅山教区兼天主教高平教区吴光杰主教为天主教河内总教区署理主教;及至2005年2月19日,范廷颂因获批辞去总主教职务而荣休;吴光杰同日真除天主教河内总教区总主教职务。范廷颂于2009年2月22日清晨在河内离世,享年89岁;其葬礼于同月26日上午在天主教河内总教区总主教座堂举行。'], 'orig_answer_text': '1963年', 'start_position': 0, 'end_position': 0, 'is_impossible': False}

该预处理似乎无法正确的进行token的切分以及start、end位置的查找。

IOError in tf.io.TFRecordWriter(output_file)

there is a bug in finetune/preprocessing.py
def _serialize_dataset(self, tasks, is_training, split):

utils.mkdir(tfrecords_path.rsplit("/", 1)[0])
this code used "/" as split flag to split out_file path ,if this code running in widows ,it might make the filename as directory ,then it will make IOError in method "def serialize_examples(self, examples, is_training, output_file, batch_size):" -->with tf.io.TFRecordWriter(output_file) as writer:
PLS use "utils.mkdir(tfrecords_path.rsplit(os.sep,1)[0])" insteadof "utils.mkdir(tfrecords_path.rsplit("/", 1)[0])"

run_pretraining的hparams的参数能共享下吗?

在跑run_pretraining的时候碰到
logits = tf.matmul(hidden, model.get_embedding_table(),
transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)

这行报错,原因是hidden是的一个rank=3的tensor, model.get_embedding_table()是一个等于2的tensor,不知道有没有碰到这样的错误?

Sentiment classification

您好,很感谢您的分享,我想用ELECTRA-large做二分类的情感分类,该如何进行代码复现呢,目前我遇到了问题不知道怎么解决,望您指导,谢谢

download issue

你好,这里的ELECTRA-large, Chinese (new)不能正常下载,可以解决吗

small模型预训练参数

你好,我进行了small模型预训练,然后模型大小跟提供的不太一样,然后调用出现权重形状不兼容的错误,想问下预训练的参数这些,是多少?

error in loading checkpoints for pretraining

error in loading checkpoints for pretraining, adam_m is missing?

2020-08-11 22:40:26.262591: W tensorflow/core/framework/op_kernel.cc:1502] OP_REQUIRES failed at save_restore_v2_ops.cc:184 : Not found: Key discriminator_predictions/dense/bias/adam_m not found in checkpoint
ERROR:tensorflow:Error recorded from training_loop: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key discriminator_predictions/dense/bias/adam_m not found in checkpoint
	 [[node save/RestoreV2 (defined at run_pretraining.py:363) ]]

Original stack trace for 'save/RestoreV2':
  File "run_pretraining.py", line 404, in <module>
    main()
  File "run_pretraining.py", line 400, in main
    args.model_name, args.data_dir, **hparams))
  File "run_pretraining.py", line 363, in train_or_eval
    max_steps=config.num_train_steps)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2871, in train
    saving_listeners=saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 367, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1158, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1192, in _train_model_default
    saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1480, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 584, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1007, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 725, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1200, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1205, in _create_session
    return self._sess_creator.create_session()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 871, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 638, in create_session
    self._scaffold.finalize()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 237, in finalize
    self._saver.build()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 837, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 875, in _build
    build_restore=build_restore)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 502, in _build_internal
    restore_sequentially, reshape)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 381, in _AddShardedRestoreOps
    name="restore_shard"))
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 328, in _AddRestoreOps
    restore_sequentially)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 575, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1696, in restore_v2
    name=name)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()

Traceback (most recent call last):
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Key discriminator_predictions/dense/bias/adam_m not found in checkpoint
	 [[{{node save/RestoreV2}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 1286, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key discriminator_predictions/dense/bias/adam_m not found in checkpoint
	 [[node save/RestoreV2 (defined at run_pretraining.py:363) ]]

Original stack trace for 'save/RestoreV2':
  File "run_pretraining.py", line 404, in <module>
    main()
  File "run_pretraining.py", line 400, in main
    args.model_name, args.data_dir, **hparams))
  File "run_pretraining.py", line 363, in train_or_eval
    max_steps=config.num_train_steps)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2871, in train
    saving_listeners=saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 367, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1158, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1192, in _train_model_default
    saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1480, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 584, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1007, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 725, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1200, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1205, in _create_session
    return self._sess_creator.create_session()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 871, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 638, in create_session
    self._scaffold.finalize()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 237, in finalize
    self._saver.build()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 837, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 875, in _build
    build_restore=build_restore)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 502, in _build_internal
    restore_sequentially, reshape)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 381, in _AddShardedRestoreOps
    name="restore_shard"))
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 328, in _AddRestoreOps
    restore_sequentially)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 575, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1696, in restore_v2
    name=name)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 1296, in restore
    names_to_keys = object_graph_key_mapping(save_path)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 1614, in object_graph_key_mapping
    object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 678, in get_tensor
    return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str))
tensorflow.python.framework.errors_impl.NotFoundError: Key _CHECKPOINTABLE_OBJECT_GRAPH not found in checkpoint

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "run_pretraining.py", line 404, in <module>
    main()
  File "run_pretraining.py", line 400, in main
    args.model_name, args.data_dir, **hparams))
  File "run_pretraining.py", line 363, in train_or_eval
    max_steps=config.num_train_steps)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2876, in train
    rendezvous.raise_errors()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/error_handling.py", line 131, in raise_errors
    six.reraise(typ, value, traceback)
  File "/home/test/anaconda3/lib/python3.7/site-packages/six.py", line 693, in reraise
    raise value
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2871, in train
    saving_listeners=saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 367, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1158, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1192, in _train_model_default
    saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1480, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 584, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1007, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 725, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1200, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1205, in _create_session
    return self._sess_creator.create_session()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 871, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 647, in create_session
    init_fn=self._scaffold.init_fn)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/session_manager.py", line 290, in prepare_session
    config=config)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/session_manager.py", line 220, in _restore_checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 1302, in restore
    err, "a Variable name or other graph key that is missing")
tensorflow.python.framework.errors_impl.NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key discriminator_predictions/dense/bias/adam_m not found in checkpoint
	 [[node save/RestoreV2 (defined at run_pretraining.py:363) ]]

Original stack trace for 'save/RestoreV2':
  File "run_pretraining.py", line 404, in <module>
    main()
  File "run_pretraining.py", line 400, in main
    args.model_name, args.data_dir, **hparams))
  File "run_pretraining.py", line 363, in train_or_eval
    max_steps=config.num_train_steps)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2871, in train
    saving_listeners=saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 367, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1158, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1192, in _train_model_default
    saving_listeners)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1480, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 584, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1007, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 725, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1200, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 1205, in _create_session
    return self._sess_creator.create_session()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 871, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 638, in create_session
    self._scaffold.finalize()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py", line 237, in finalize
    self._saver.build()
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 837, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 875, in _build
    build_restore=build_restore)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 502, in _build_internal
    restore_sequentially, reshape)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 381, in _AddShardedRestoreOps
    name="restore_shard"))
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 328, in _AddRestoreOps
    restore_sequentially)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 575, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1696, in restore_v2
    name=name)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/home/test/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()

keyError:3200000

Traceback (most recent call last):
File "run_finetuning.py", line 375, in
main()
File "run_finetuning.py", line 371, in main
args.model_name, args.data_dir, **hparams))
File "run_finetuning.py", line 304, in run_finetuning
scorer.write_predictions()
File "/Chinese-ELECTRA-master/finetune/qa/qa_metrics.py", line 113, in write_predictions
result = unique_id_to_result[feature[self._name + "_eid"]]
KeyError: 3200000

在預訓練的時候是否也有使用全詞遮蔽?

之前bert-wwm可以改善原先bert預訓練mask單個字的問題,全詞遮蔽(wwm)可以使模型學到更多詞與詞的關係。
目前這一個版本的electra在預訓練的時候是否也有使用全詞遮蔽(wwm)?

请教,是否观察到 electra 较 bert/roberta 收敛更快?

比较 pretraining 不同 steps 的 checkpoint。同 step 对应的 checkpoint,electra 100% label 学习的优势,在 finetuning 效果上,论文里是显著快于 bert 的。

不知道复现是否有这个结论呢?我们在做一个类似的策略,收敛速度上并没有论文显著。

下游微调任务 CMRC 2018训练集和开发集

按照网站的提示,没有运行成功啊,json格式的数据转换成tfrecord的代码,已经写好了吗?还是需要自己加。是否我只要按照提示格式,放好数据直接运行吗?

对于 readme 的错误或疑问

readme中写的是:task-name:任务名称,本例中为cmrc2018。本目录中的代码已适配了以上六个中文任务,task-name分别为cmrc2018,drcd,xnli,chnsenticorp,lcqmc,bqcorpus。

但是代码里是: if task.name in ["cola", "mrpc", "mnli", "sst", "rte", "qnli", "qqp", "sts"]:

应该要将 readme 里的 xnli 改为 mnli 吧?

关于Loss计算的问题

您好,下游任务损失函数是加和(reduce_sum)计算,而非对Batch求均值(reduce_mean),感觉对部分非Adam优化器的结果会产生影响。想请教一下,这一细节是有意为之还是不会影响模型效果?此外非常感谢共享中文预训练模型。

关于并行

您好,我在8个2080Ti上进行预训练,只有一个GPU显存几乎占满,另外7个都是只是用100M左右,加大batch_size就出现OOM问题了,请问代码中,有设置GPU使用的超参数吗,看了一遍没有找到。

关于在领域数据集上继续预训练后微调效果很差的问题

你好,我在提供的base模型基础上用自己的数据继续预训练大约100w步,生成器loss 0.9左右,判别器loss 0.18左右,之后用预训练的模型在分类任务上微调,准确率只有10%+,而且用训练集测试准确率也只有30%。如果直接在提供的base模型基础上准确率能到90%,这是什么原因呢

关于ELECTRA的预训练方法

ELECTRA由生成器+判决器构成,生成器负责把[MASK]替换成实际tokens,判决器负责区分替换结果是否和实际data中相同

但是,论文提及:

  1. Typically k = [0.15n], i.e., 15% of the tokens are masked out

  2. if the generator happens to generate the correct token, that token is considered “real” instead of “fake

也就是说,只有 (1 - generator_inference_acc) * 0.15的token会被预测成fake,训练到后期就是一个极度不均衡的二分类问题(对于判决器而言),为什么判决器不会受到影响呢?

请问finetune时应如何设置token type id?

在Bert中若处理输入为两个句子的相关任务(例如语义相似性打分等),常使用token_type_embedding对两个句子分别加上不同的embedding;这一做法只需要在transformers的API中设置token_type_id(一个句子全为0,另一个全为1)即可实现。
然而,electra的预训练好像取消了NSP任务,相应的也就没有训练这个句子embedding(抱歉我不是很确定,只是看了一下论文好像没写这一点😂),所以我想请教一下使用token_type_id这一做法是否在electra中也可以通用呢?如果不行,对于两个句子的输入,推荐的处理方法是什么呢?谢谢!

huggingface的tokenizer问题

首先感谢作者所做的工作,我在使用过程中有两个疑问

  1. 请问ELECTRA可以使用huggingface的bert-tokenizer吗?词表完全相同是否能用相同的tokenizer?
  2. 请问Chinese-ELECTRA的预训练语料有多大呢?

多分类任务如何修改?

如果进行多分类情感分析的话,修改哪里呢?直接将CnSentiCorp任务里的["0","1"]修改为["1","2","3","4","5"]吗?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.