Giter Site home page Giter Site logo

bert4keras3's Introduction

bert4keras3

背景

bert4keras是我最喜欢的库之一,但在现在来说其后端tf1.15显得有点落后,因此本库目的在实现对其的升级

目的

兼容keras3及其对应后端 目前已经成功实现了bert4keras所支持的所有预训练模型的兼容
bert4keras实现的优化器目前暂时不做兼容,除开优化器部分外,如何使用请参考bert4keras的example,本仓库的example只提供了如何把模型load出来的测试

api文档

请参考api说明

安装

因为我是个人开发,连草台班子都不是,经常会发布修改bug的版本,所以建议安装最新版本

pip install --upgrade bert4keras3

后端安装

如果你用不是tensorflow后端,我建议安装一个tensorflow-cpu==2.10

pip3 install tensorflow-cpu==2.10
pip3 install --upgrade keras

如果你用torch后端,直接安装最新的torch就行了。但是我个人建议torch后端只用来调试

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip3 install keras

如果你需要使用tensorflow后端,那我建议你安装tensorflow的2.15

pip3 install tensorflow[and-cuda]==2.15
pip3 install --upgrade keras

当然你想安装最新的也可以,但是问题就是加载苏神的权重会有点问题。谷歌的尿性你们懂的
还有就是cuda版本要大于12.2,你的服务器不一定能同步。可以看tensorflow的cuda、cudnn版本对应
如果你想使用jax后端,jax安装建议看keras官方文档的jax-cuda要求
比如在keras3.3.3的情况下,官方推荐的版本是jax 0.4.23,那安装可以这么写

#cuda12
pip3 install jax[cuda12_pip]==0.4.23 --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#cuda11
pip3 install jax[cuda11_pip]==0.4.23 --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

pip3 install --upgrade keras

jax和tensorflow后端只能在linux使用cuda

功能

初始版本与bert4keras基本相同,可以参考https://github.com/bojone/bert4keras
但需要注意的是,如果bert4keras的example中必须要要tf.keras的,在本库中依然需要
如果你需要使用tf之外的其他后端,需要修改bert4keras中的tf api
由于优化器部分维护工作量过大,本库放弃了对器优化器的维护。并且以后如果推出优化器功能,只keras3版本
目前keras3支持原生梯度累积、ema,AdamW等,如果需要什么keras不支持的功能欢迎提issue
除此之外重计算/gradient_checkpoint功能目前依然不支持keras3

如何实现多版本兼容

如果你只是想兼容torch、tf和jax,那么我建议你使用纯keras的api实现,参考keras.io。对于精细的算子可以使用keras.ops,如果keras实在没有算子,那你只能提供一个api的三后端实现了
如果你想兼容keras2和tf.api,因为在keras3中增加了ops系列并且删除了绝大部分keras.backend中的算子操作。因此如果你需要兼容tf2是有一定困难的。
为了解决这个问题,bert4keras3.ops手动对齐了keras3中的ops,api。所以如果你想要兼容keras2和tf.keras,那么在编写代码时请from bert4keras3 import ops,在keras2中使用的是我们对齐的api,而在keras3中使用的是keras.ops。通过这种方法,你可以很容易地实现更好的兼容性

权重

兼容bert4keras支持加载的权重,你可以在本来bert4keras支持的tf.keras、tf1.15-tf2.15和keras3加载:

模型分类 模型名称 权重链接 支持kv-cache 生成
bert/roberta Google原版bert github
brightmart版roberta github
哈工大版roberta github
追一开源bert github
LaBSE(多国语言BERT) github
albert 谷歌albert github x
brightmart版albert github x
苏神转换后的albert github x
NEZHA 双向NEZHA github x
单向NEZHA github x
T5 谷歌T5 github
MT5 github
苏神T5-pegasus github
T5.1.1 github
ELECTRA Google原版ELECTRA github x
哈工大版ELECTRA github x
CLUE版ELECTRA github x
GPT-oai GPT_OpenAI github x
GPT2-ML GPT2-ML github x
GAU GAU-ALPHA github x
Roformer 苏神原版roformer github
roformer-sim github
Roformerv2 苏神原版roformer-v2 github

bert4keras3的新增加的模型权重,不再使用ckpt存储

通过build_transformer_model( keras_weights_path='xx.weights.h5')方法读取权重,只能使用keras3加载

模型分类 模型名称 权重链接 数据类型 分词器
T5.1.1 ChatYuan 百度网盘 FP32 SpTokenizer
Flan-T5-small 百度网盘 FP32 SpTokenizer
Flan-T5-base 百度网盘 FP32 SpTokenizer
Flan-T5-large 百度网盘 FP32 SpTokenizer
Flan-T5-xl 百度网盘 FP32 SpTokenizer
MT5-large 百度网盘 FP32 SpTokenizer
UMT5-small 百度网盘 FP32 SpTokenizer
UMT5-base 百度网盘 FP32 SpTokenizer
UMT5-xl 百度网盘 FP32 SpTokenizer
Gemma Gemma-2b 百度网盘 BF16 SpTokenizer
Gemma-2b-Code 百度网盘 BF16 SpTokenizer
Gemma-2b-it 百度网盘 BF16 SpTokenizer
Gemma1.1-2b-it 百度网盘 BF16 SpTokenizer
Gemma-7b 百度网盘 BF16 SpTokenizer
Gemma-7b-Code 百度网盘 BF16 SpTokenizer
Gemma-7b-it 百度网盘 BF16 SpTokenizer
Gemma1.1-7b-it 百度网盘 BF16 SpTokenizer
Gemma-7b-it-Code 百度网盘 BF16 SpTokenizer
Llama Yi-6B 百度网盘 BF16 AutoTokenizer
Yi-6B-it 百度网盘 BF16 AutoTokenizer
Yi-9B 百度网盘 BF16 AutoTokenizer
Yi-1.5-6B 百度网盘 BF16 AutoTokenizer
Yi-1.5-9B 百度网盘 BF16 AutoTokenizer
Llama3-8B 百度网盘 BF16 AutoTokenizer
Llama3-8B-it 百度网盘 BF16 AutoTokenizer
千问 Qwen-0.5B 百度网盘 BF16 AutoTokenizer
Qwen-0.5B-it 百度网盘 BF16 AutoTokenizer
Qwen-1.8B 百度网盘 BF16 AutoTokenizer
Qwen-1.8B-it 百度网盘 BF16 AutoTokenizer
Qwen-4B 百度网盘 BF16 AutoTokenizer
Qwen-4B-it 百度网盘 BF16 AutoTokenizer
Qwen-7B 百度网盘 BF16 AutoTokenizer
Qwen-7B-it 百度网盘 BF16 AutoTokenizer
Qwen-14B 百度网盘 BF16 AutoTokenizer
Qwen-14B-it 百度网盘 BF16 AutoTokenizer

注意事项

  • 注1:brightmart版albert的开源时间早于Google版albert,这导致早期brightmart版albert的权重与Google版的不完全一致,换言之两者不能直接相互替换。为了减少代码冗余,bert4keras的0.2.4及后续版本均只支持加载Google版以brightmart版中带Google字眼的权重。如果要加载早期版本的权重,请用0.2.3版本,或者考虑作者转换过的albert_zh。(苏神注)
  • 注2:下载下来的ELECTRA权重,如果没有json配置文件的话,参考这里自己改一个(需要加上type_vocab_size字段)。(苏神注)
  • 注3: 模型分类这里会跳转到使用的example
  • 注4:SpTokenizer指的是bert4keras3.tokenizers.SpTokenizer,AutoTokenizer指的是transformers的分词器。用法不同需要注意
  • 注5:因为不能转换全部的权重,所以我提供了转化权重的脚本,有需要自己去转。
  • 注6:bert4keras3的新增加的模型权重均支持kv-cache生成
  • 注7: it模型指的是instruct模型,也就是我们俗话说的chat模型

版本更新

2023.12.30发布bert4keras3的第一个版本1.0

对bert4keras除优化器部分外的升级,实现对tensorflow,jax,torch的多后端兼容

1.31号更新,发布1.1版本

转换了chatyuan模型权重(基于t5模型)
更新了支持批量运算的t5-cache推理版本,详细使用参考t5-cache的使用example 。里面较为详细地列出了cache模型要如何使用。
除了T5,还增加了bertroformer/roformer-v2的cache支持,用法和t5一样,example里只是测试一下与greedy是否一致

3.17号更新,发布1.2版本

增加了对weights.h5的读取支持
增加了lora支持,可以通过设置os.environ["ENABLE_LORA"]='1' 启动lora训练,注意的是除了lora之外的参数全部会被冻结
增加了flash-attention支持,可以通过设置os.environ["FLASH_ATTN"]='1'使用flash-attention
但是需要注意的是,tensorflow不支持。而jax在https://github.com/nshepperd/flash_attn_jax/releases 下载,torch则是 https://github.com/Dao-AILab/flash-attention

4.25号更新,发布1.3版本

重新整理了苏神的代码,更新了对 Gemma,Qwen,和llama系列模型(llama3和Yi)的支持,转换了UMT5,FlanT5的权重,并且提供了转换脚本,大家可以自行转换权重

bert4keras3's People

Contributors

pass-lin avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.