bert4keras是我最喜欢的库之一,但在现在来说其后端tf1.15显得有点落后,因此本库目的在实现对其的升级
兼容keras3及其对应后端 目前已经成功实现了bert4keras所支持的所有预训练模型的兼容
bert4keras实现的优化器目前暂时不做兼容,除开优化器部分外,如何使用请参考bert4keras的example,本仓库的example只提供了如何把模型load出来的测试
请参考api说明
因为我是个人开发,连草台班子都不是,经常会发布修改bug的版本,所以建议安装最新版本
pip install --upgrade bert4keras3
如果你用不是tensorflow后端,我建议安装一个tensorflow-cpu==2.10
pip3 install tensorflow-cpu==2.10
pip3 install --upgrade keras
如果你用torch后端,直接安装最新的torch就行了。但是我个人建议torch后端只用来调试
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip3 install keras
如果你需要使用tensorflow后端,那我建议你安装tensorflow的2.15
pip3 install tensorflow[and-cuda]==2.15
pip3 install --upgrade keras
当然你想安装最新的也可以,但是问题就是加载苏神的权重会有点问题。谷歌的尿性你们懂的
还有就是cuda版本要大于12.2,你的服务器不一定能同步。可以看tensorflow的cuda、cudnn版本对应
如果你想使用jax后端,jax安装建议看keras官方文档的jax-cuda要求
比如在keras3.3.3的情况下,官方推荐的版本是jax 0.4.23,那安装可以这么写
#cuda12
pip3 install jax[cuda12_pip]==0.4.23 --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#cuda11
pip3 install jax[cuda11_pip]==0.4.23 --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip3 install --upgrade keras
jax和tensorflow后端只能在linux使用cuda
初始版本与bert4keras基本相同,可以参考https://github.com/bojone/bert4keras
但需要注意的是,如果bert4keras的example中必须要要tf.keras的,在本库中依然需要
如果你需要使用tf之外的其他后端,需要修改bert4keras中的tf api
由于优化器部分维护工作量过大,本库放弃了对器优化器的维护。并且以后如果推出优化器功能,只keras3版本
目前keras3支持原生梯度累积、ema,AdamW等,如果需要什么keras不支持的功能欢迎提issue
除此之外重计算/gradient_checkpoint功能目前依然不支持keras3
如果你只是想兼容torch、tf和jax,那么我建议你使用纯keras的api实现,参考keras.io。对于精细的算子可以使用keras.ops,如果keras实在没有算子,那你只能提供一个api的三后端实现了
如果你想兼容keras2和tf.api,因为在keras3中增加了ops系列并且删除了绝大部分keras.backend中的算子操作。因此如果你需要兼容tf2是有一定困难的。
为了解决这个问题,bert4keras3.ops手动对齐了keras3中的ops,api。所以如果你想要兼容keras2和tf.keras,那么在编写代码时请from bert4keras3 import ops,在keras2中使用的是我们对齐的api,而在keras3中使用的是keras.ops。通过这种方法,你可以很容易地实现更好的兼容性
模型分类 | 模型名称 | 权重链接 | 支持kv-cache 生成 |
---|---|---|---|
bert/roberta | Google原版bert | github | √ |
brightmart版roberta | github | √ | |
哈工大版roberta | github | √ | |
追一开源bert | github | √ | |
LaBSE(多国语言BERT) | github | √ | |
albert | 谷歌albert | github | x |
brightmart版albert | github | x | |
苏神转换后的albert | github | x | |
NEZHA | 双向NEZHA | github | x |
单向NEZHA | github | x | |
T5 | 谷歌T5 | github | √ |
MT5 | github | √ | |
苏神T5-pegasus | github | √ | |
T5.1.1 | github | √ | |
ELECTRA | Google原版ELECTRA | github | x |
哈工大版ELECTRA | github | x | |
CLUE版ELECTRA | github | x | |
GPT-oai | GPT_OpenAI | github | x |
GPT2-ML | GPT2-ML | github | x |
GAU | GAU-ALPHA | github | x |
Roformer | 苏神原版roformer | github | √ |
roformer-sim | github | √ | |
Roformerv2 | 苏神原版roformer-v2 | github | √ |
模型分类 | 模型名称 | 权重链接 | 数据类型 | 分词器 |
---|---|---|---|---|
T5.1.1 | ChatYuan | 百度网盘 | FP32 | SpTokenizer |
Flan-T5-small | 百度网盘 | FP32 | SpTokenizer | |
Flan-T5-base | 百度网盘 | FP32 | SpTokenizer | |
Flan-T5-large | 百度网盘 | FP32 | SpTokenizer | |
Flan-T5-xl | 百度网盘 | FP32 | SpTokenizer | |
MT5-large | 百度网盘 | FP32 | SpTokenizer | |
UMT5-small | 百度网盘 | FP32 | SpTokenizer | |
UMT5-base | 百度网盘 | FP32 | SpTokenizer | |
UMT5-xl | 百度网盘 | FP32 | SpTokenizer | |
Gemma | Gemma-2b | 百度网盘 | BF16 | SpTokenizer |
Gemma-2b-Code | 百度网盘 | BF16 | SpTokenizer | |
Gemma-2b-it | 百度网盘 | BF16 | SpTokenizer | |
Gemma1.1-2b-it | 百度网盘 | BF16 | SpTokenizer | |
Gemma-7b | 百度网盘 | BF16 | SpTokenizer | |
Gemma-7b-Code | 百度网盘 | BF16 | SpTokenizer | |
Gemma-7b-it | 百度网盘 | BF16 | SpTokenizer | |
Gemma1.1-7b-it | 百度网盘 | BF16 | SpTokenizer | |
Gemma-7b-it-Code | 百度网盘 | BF16 | SpTokenizer | |
Llama | Yi-6B | 百度网盘 | BF16 | AutoTokenizer |
Yi-6B-it | 百度网盘 | BF16 | AutoTokenizer | |
Yi-9B | 百度网盘 | BF16 | AutoTokenizer | |
Yi-1.5-6B | 百度网盘 | BF16 | AutoTokenizer | |
Yi-1.5-9B | 百度网盘 | BF16 | AutoTokenizer | |
Llama3-8B | 百度网盘 | BF16 | AutoTokenizer | |
Llama3-8B-it | 百度网盘 | BF16 | AutoTokenizer | |
千问 | Qwen-0.5B | 百度网盘 | BF16 | AutoTokenizer |
Qwen-0.5B-it | 百度网盘 | BF16 | AutoTokenizer | |
Qwen-1.8B | 百度网盘 | BF16 | AutoTokenizer | |
Qwen-1.8B-it | 百度网盘 | BF16 | AutoTokenizer | |
Qwen-4B | 百度网盘 | BF16 | AutoTokenizer | |
Qwen-4B-it | 百度网盘 | BF16 | AutoTokenizer | |
Qwen-7B | 百度网盘 | BF16 | AutoTokenizer | |
Qwen-7B-it | 百度网盘 | BF16 | AutoTokenizer | |
Qwen-14B | 百度网盘 | BF16 | AutoTokenizer | |
Qwen-14B-it | 百度网盘 | BF16 | AutoTokenizer |
注意事项
- 注1:brightmart版albert的开源时间早于Google版albert,这导致早期brightmart版albert的权重与Google版的不完全一致,换言之两者不能直接相互替换。为了减少代码冗余,bert4keras的0.2.4及后续版本均只支持加载Google版以brightmart版中带Google字眼的权重。如果要加载早期版本的权重,请用0.2.3版本,或者考虑作者转换过的albert_zh。(苏神注)
- 注2:下载下来的ELECTRA权重,如果没有json配置文件的话,参考这里自己改一个(需要加上
type_vocab_size
字段)。(苏神注) - 注3: 模型分类这里会跳转到使用的example
- 注4:SpTokenizer指的是bert4keras3.tokenizers.SpTokenizer,AutoTokenizer指的是transformers的分词器。用法不同需要注意
- 注5:因为不能转换全部的权重,所以我提供了转化权重的脚本,有需要自己去转。
- 注6:bert4keras3的新增加的模型权重均支持kv-cache生成
- 注7: it模型指的是instruct模型,也就是我们俗话说的chat模型
对bert4keras除优化器部分外的升级,实现对tensorflow,jax,torch的多后端兼容
转换了chatyuan模型权重(基于t5模型)
更新了支持批量运算的t5-cache推理版本,详细使用参考t5-cache的使用example 。里面较为详细地列出了cache模型要如何使用。
除了T5,还增加了bert和
roformer/roformer-v2的cache支持,用法和t5一样,example里只是测试一下与greedy是否一致
增加了对weights.h5的读取支持
增加了lora支持,可以通过设置os.environ["ENABLE_LORA"]='1' 启动lora训练,注意的是除了lora之外的参数全部会被冻结
增加了flash-attention支持,可以通过设置os.environ["FLASH_ATTN"]='1'使用flash-attention
但是需要注意的是,tensorflow不支持。而jax在https://github.com/nshepperd/flash_attn_jax/releases 下载,torch则是 https://github.com/Dao-AILab/flash-attention
重新整理了苏神的代码,更新了对 Gemma,Qwen,和llama系列模型(llama3和Yi)的支持,转换了UMT5,FlanT5的权重,并且提供了转换脚本,大家可以自行转换权重