Giter Site home page Giter Site logo

gau-alpha-pytorch's Introduction

GAU-alpha-pytorch

pytorch版本https://github.com/ZhuiyiTechnology/GAU-alpha

News

  • 2022/05/11 感谢苏神提醒,添加了一个注释,其中RoFormerV2*表示未经多任务学习的RoFormerV2模型。
  • 2022/04/22 初始化仓库,添加初步的代码, 添加paddle版本gau_alpha代码。
  • 2022/04/30 添加CLUE分类任务代码。

Install

pip install git+https://github.com/JunnYu/GAU-alpha-pytorch.git
or
pip install gau_alpha

精度对齐

python compare.py
# bert4keras vs pytorch
# mean diff : tensor(6.9320e-07)
# max diff : tensor(3.9101e-05)

torch版本使用

依赖:

  • torch
  • transformers
import torch
from gau_alpha import GAUAlphaForMaskedLM, GAUAlphaTokenizer

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = GAUAlphaTokenizer.from_pretrained(
    "junnyu/chinese_GAU-alpha-char_L-24_H-768"
)
pt_model = GAUAlphaForMaskedLM.from_pretrained(
    "junnyu/chinese_GAU-alpha-char_L-24_H-768"
)
pt_model.eval()
pt_inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
    if id == tokenizer.mask_token_id:
        val, idx = pt_outputs[i].softmax(-1).topk(k=5)
        tokens = tokenizer.convert_ids_to_tokens(idx)
        new_tokens = []
        for v, t in zip(val.cpu(), tokens):
            new_tokens.append(f"{t}+{round(v.item(),4)}")
        pt_outputs_sentence += "[" + "||".join(new_tokens) + "]"
    else:
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
        )
print(pt_outputs_sentence)
# pytorch: 今天[天+0.8657||气+0.0535||阳+0.0165||,+0.0126||晴+0.0111]很好,我[要+0.4619||想+0.4352||又+0.0252||就+0.0157||跑+0.0064]去公园玩。

Paddle版本使用

依赖:

  • paddlepaddle>=2.2.0
  • paddlenlp
import paddle
from transformers import BertTokenizer as GAUAlphaTokenizer
from gau_alpha_paddle import GAUAlphaForMaskedLM

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = GAUAlphaTokenizer.from_pretrained(
    "junnyu/chinese_GAU-alpha-char_L-24_H-768"
)
pd_model = GAUAlphaForMaskedLM.from_pretrained("chinese_GAU-alpha-char_L-24_H-768")
pd_model.eval()
pd_inputs = tokenizer(text)
pd_inputs = {k: paddle.to_tensor([v]) for k, v in pd_inputs.items()}

with paddle.no_grad():
    pd_outputs = pd_model(**pd_inputs)[0][0]

pd_outputs_sentence = "paddle: "
for i, id in enumerate(tokenizer.encode(text)):
    if id == tokenizer.mask_token_id:
        val, idx = paddle.nn.functional.softmax(pd_outputs[i], -1).topk(k=5)
        tokens = tokenizer.convert_ids_to_tokens(idx)
        new_tokens = []
        for v, t in zip(val.cpu(), tokens):
            new_tokens.append(f"{t}+{round(v.item(),4)}")
        pd_outputs_sentence += "[" + "||".join(new_tokens) + "]"
    else:
        pd_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
        )
print(pd_outputs_sentence)
# paddle: 今天[天+0.8657||气+0.0535||阳+0.0165||,+0.0126||晴+0.0111]很好,我[要+0.4619||想+0.4352||又+0.0252||就+0.0157||跑+0.0064]去公园玩。

介绍

评测对比

CLUE-dev榜单分类任务结果,base版本。

iflytek tnews afqmc cmnli ocnli wsc csl
BERT 60.06 56.80 72.41 79.56 73.93 78.62 83.93
RoBERTa 60.64 58.06 74.05 81.24 76.00 87.50 84.50
RoFormer 60.91 57.54 73.52 80.92 76.07 86.84 84.63
RoFormerV2* 60.87 56.54 72.75 80.34 75.36 80.92 84.67
GAU-α 61.41 57.76 74.17 81.82 75.86 79.93 85.67
RoFormerV2-pytorch 62.87 59.03 76.20 80.85 79.73 87.82 91.87
GAU-α-pytorch(Adafactor) 61.18 57.52 73.42 80.91 75.69 80.59 85.5
GAU-α-pytorch(AdamW wd0.01 warmup0.1) 60.68 57.95 73.08 81.02 75.36 81.25 83.93

CLUE-test榜单分类任务结果,base版本。

iflytek tnews afqmc cmnli ocnli wsc csl
RoFormerV2-pytorch 63.15 58.24 75.42 80.59 74.17 83.79 83.73
GAU-α-pytorch(Adafactor) 61.38 57.08 74.05 80.37 73.53 74.83 85.6
GAU-α-pytorch(AdamW wd0.01 warmup0.1) 60.54 57.67 72.44 80.32 72.97 76.55 84.13

CLUE-dev集榜单阅读理解和NER结果

cmrc2018 c3 chid cluener
BERT 56.17 60.54 85.69 79.45
RoBERTa 56.54 67.66 86.71 79.47
RoFormer 56.26 67.24 86.57 79.72
RoFormerV2* 57.91 64.62 85.09 81.08
GAU-α 58.09 68.24 87.91 80.01

注:

  • 其中RoFormerV2*表示的是未进行多任务学习的RoFormerV2模型,该模型苏神并未开源,感谢苏神的提醒。
  • 其中不带有pytorch后缀结果都是从GAU-alpha仓库复制过来的。
  • 其中带有pytorch后缀的结果都是自己训练得出的。

引用

Bibtex:

@techreport{gau-alpha,
  title={GAU-α: GAU-based Transformers for NLP - ZhuiyiAI},
  author={Jianlin Su, Shengfeng Pan, Bo Wen, Yunfeng Liu},
  year={2022},
  url="https://github.com/ZhuiyiTechnology/GAU-alpha",
}

Tips:

  • 感谢苏神提供的模型和代码!

gau-alpha-pytorch's People

Contributors

junnyu avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.