Comments (2)
Is this expected behavior, or could there be a bug?
Mistral loads in bfloat16 by default. I noticed this caused issues in Colab with the JAX backend (TensorFlow runs fine though, with 16.5 GB of RAM). To make the model use the dtype set via the keras.mixed_precision module instead, the workaround is to pass dtype=None to the from_preset method:
mistral_lm = keras_nlp.models.MistralCausalLM.from_preset('mistral_instruct_7b_en', preprocessor=preprocessor, dtype=None)
Let me know if this lowers the RAM usage. This will be fixed in the next release.
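As a rough back-of-envelope check of why the load dtype matters here (assuming a 7B-parameter count and counting weights only; activations, the KV cache, and framework overhead come on top):

```python
# Rough weight-memory estimate for a 7B-parameter model.
# Assumption: memory ~= n_params * bytes_per_param (weights only).
N_PARAMS = 7_000_000_000

def weight_gib(n_params: int, bytes_per_param: int) -> float:
    """Return the approximate weight footprint in GiB."""
    return n_params * bytes_per_param / 2**30

print(f"float32:  {weight_gib(N_PARAMS, 4):.1f} GiB")   # ~26.1 GiB
print(f"bfloat16: {weight_gib(N_PARAMS, 2):.1f} GiB")   # ~13.0 GiB
```

This is consistent with the report above: a bfloat16 load fits under Colab's ~16.5 GB of RAM, while a float32 load would not.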
from keras-nlp.
Thanks for the bug report! Just synced up with @tirthasheshpatel. We want to change two things here:
- By default, Mistral should follow the global Keras settings: keras.mixed_precision.set_global_policy("mixed...") -> variables load as float32; keras.config.set_floatx("bfloat16") -> variables load as bfloat16.
- There is a bug with the JAX backend only, where generation for Mistral consumes significantly too much CPU and GPU memory. It's a one-liner fix on our side, I think.
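The intended precedence described above can be sketched in plain Python. This is a hypothetical illustration, not the actual keras-nlp code: resolve_variable_dtype and its arguments are invented names for the behavior the comment describes.

```python
def resolve_variable_dtype(explicit_dtype=None,
                           mixed_precision_policy=None,
                           floatx="float32"):
    """Sketch of the intended dtype precedence when loading weights.

    1. An explicit dtype passed to from_preset wins.
    2. A mixed-precision policy (e.g. "mixed_bfloat16") keeps
       variables in float32; only compute runs in low precision.
    3. Otherwise, variables follow the global floatx setting.
    """
    if explicit_dtype is not None:
        return explicit_dtype
    if mixed_precision_policy is not None and \
            mixed_precision_policy.startswith("mixed"):
        return "float32"
    return floatx

print(resolve_variable_dtype(mixed_precision_policy="mixed_bfloat16"))  # float32
print(resolve_variable_dtype(floatx="bfloat16"))                        # bfloat16
```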
These are both simple but important fixes; we should have a patch release out in a couple of days. Thanks @deep-diver!
from keras-nlp.
Related Issues (20)
- Gemma discrepancies HOT 1
- Dropout is not called in the training regime in TransformerEncoder and others HOT 2
- Data-Parallel Training with KerasNLP and tf.distribute example dataset problem HOT 4
- Feature Request: Transformer Debugger - Debugging and controlling the behavior of transformer based LLM models. HOT 3
- Add Mistral 0.2 models as possible presets HOT 3
- keras-nlp insists I use the (buggy) Tensorflow 2.16.1 which does not work with my GPU HOT 12
- `SentencePieceTokenizer` inside a `keras.models.Model` fails to be reconstructed during `keras.saving.load_model()` HOT 2
- Add grok-1
- [RfC] Ideas for better Hugging Face Hub integration HOT 7
- Any plans for QLora? HOT 2
- create local variable per_token_loss in score method to global. So that we can modify loss function. HOT 4
- Why not use low precision matmul for reverse embedding in gemma model HOT 4
- Keep kv cache as list of tensors maybe better than one tensor HOT 3
- Issue when fine-tuning Albert - Resource localhost/_0_SentencepieceOp/N10tensorflow4text12_GLOBAL__N_121SentencepieceResourceE does not exist. HOT 3
- Issue instantiating a keras_nlp.models.Backbone from a model preset of Hugging Face handles HOT 4
- How gemma_lm.preprocessor.sequence_length dealing with large input data HOT 3
- Any plans for Llama 3?
- Any plans for more Llama type models? HOT 1
- Samplers in Gemma model HOT 6