Transformer models with arbitrary graph attention patterns in JAX.
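
For context, a minimal sketch of what attention over an arbitrary graph can look like in JAX (this is not the repo's actual API; `graph_attention` and its signature are made up for illustration). The pattern is just a boolean adjacency mask where `adj[i, j]` marks an edge from query `i` to key `j`:

```python
import jax
import jax.numpy as jnp


def graph_attention(q, k, v, adj):
    """Single-head attention restricted to a graph. q, k, v: [seq, dim]; adj: [seq, seq] bool."""
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])  # attention logits
    scores = jnp.where(adj, scores, -jnp.inf)   # drop edges absent from the graph
    weights = jax.nn.softmax(scores, axis=-1)   # assumes every row has at least one edge (e.g. self-loops)
    return weights @ v


# A fully connected pattern recovers standard attention.
key = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(key, (8, 16))
out = jax.jit(graph_attention)(q, k, v, jnp.ones((8, 8), dtype=bool))
print(out.shape)  # (8, 16)
```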
TODO list:
- training code
  - instantiate the model
  - jit + evaluate the model
  - training loop
  - get training working with long inputs
  - update training loop
- add attention patterns
  - fully connected attn pattern
  - window attn pattern (see the mask sketch after this list)
  - longformer-style attn pattern
    - add dilation support
    - add block window support
  - dependency graph attn pattern
  - constituency graph attn pattern
- optimizations
  - add a more efficient global / local attention computation
  - accept an attention pattern defined per layer, or share one between layers (tying)
  - tie the positional embeddings between layers
- support LoRA fine-tuning (lorax)
- unit tests
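
To make the window and longformer items above concrete, here is a hedged sketch of how those patterns could be built as boolean adjacency masks and passed to an attention function like the one sketched earlier; `window_mask` and `longformer_mask` are illustrative names, not functions from this repo:

```python
import jax.numpy as jnp


def window_mask(seq_len: int, window: int) -> jnp.ndarray:
    """Each token attends to tokens at most `window` positions away."""
    idx = jnp.arange(seq_len)
    return jnp.abs(idx[:, None] - idx[None, :]) <= window


def longformer_mask(seq_len: int, window: int, global_positions) -> jnp.ndarray:
    """Sliding window plus global tokens that attend to, and are attended by, every position."""
    is_global = jnp.zeros(seq_len, dtype=bool).at[jnp.asarray(global_positions)].set(True)
    return window_mask(seq_len, window) | is_global[:, None] | is_global[None, :]


print(window_mask(6, 1).astype(int))           # tri-diagonal band
print(longformer_mask(6, 1, [0]).astype(int))  # band plus a global first token
```

The dilation and block-window items in the TODO would then amount to changing how this band is constructed, without touching the attention computation itself.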