from typing import Optional

import torch
from torch.utils.data import RandomSampler, SequentialSampler
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from mamba_former.main import MambaFormer
tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-fp16")
print(len(tokenizer))
# Forward pass example: a batch of 100 random token ids (tokens are integers)
x = torch.randint(1, 1000, (1, 100))
# Model
model = MambaFormer(
    dim=128,
    num_tokens=len(tokenizer),
    depth=2,
    d_state=128,
    d_conv=128,
    heads=8,
    dim_head=64,
    return_tokens=True,
)
# Forward
out = model(x)
print(out)
print(out.shape)
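# Expected output (assumption, given return_tokens=True): logits over the vocabulary,
# i.e. torch.Size([1, 100, len(tokenizer)])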
# count parameters
model_size = sum(t.numel() for t in model.parameters())
print(f"parameter size: {model_size/1000**2:.1f}M parameters")
# tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token  # Llama-2 ships without a pad token; reuse EOS
tokenized_data = load_dataset("xz56/openwebtext-tokenized-small")  # pre-tokenized OpenWebText subset
print(tokenized_data)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
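# Sanity check (hedged sketch): with mlm=False the collator pads the batch and copies
# input_ids into labels, masking pad positions with -100 so the loss ignores them.
# Caveat: since pad_token == eos_token here, genuine EOS tokens are masked as well.
batch = data_collator([{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}])
print(batch["input_ids"].shape, batch["labels"].shape)  # both torch.Size([2, 3])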
output_path = "outputs"
args = TrainingArguments(
    output_dir=output_path,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy="steps",
    eval_steps=0.05,  # a float in (0, 1) is interpreted as a ratio of total training steps
    logging_steps=100,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    weight_decay=0.01,
    warmup_ratio=0.1,  # warmup_steps expects an int; a fractional warmup belongs in warmup_ratio
    lr_scheduler_type="cosine",
    learning_rate=1.5e-3,
    save_steps=0.25,
    fp16=True,
    report_to="none",
)
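# Effective train batch size: 4 per device x 2 accumulation steps = 8 per optimizer step (per GPU)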
print("Train dataset size:", len(tokenized_data['train']))
print("Test dataset size:", len(tokenized_data['test']))
class CustomTrainer(Trainer):
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        return RandomSampler(self.train_dataset)

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return None
        return SequentialSampler(eval_dataset)

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
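# NOTE (assumption): MambaFormer appears to be a plain nn.Module that returns raw logits
# rather than a Hugging Face output object, so `outputs.loss` above would fail once a
# forward pass actually runs, and `model(**inputs)` would not accept keyword arguments
# such as `input_ids`. A minimal causal-LM compute_loss sketch under that assumption:
#
#     import torch.nn.functional as F
#
#     def compute_loss(self, model, inputs, return_outputs=False):
#         input_ids = inputs["input_ids"]
#         labels = inputs.get("labels", input_ids)
#         logits = model(input_ids)                      # (batch, seq, vocab)
#         shift_logits = logits[:, :-1, :].contiguous()  # predict token t+1 from token t
#         shift_labels = labels[:, 1:].contiguous()
#         loss = F.cross_entropy(
#             shift_logits.view(-1, shift_logits.size(-1)),
#             shift_labels.view(-1),
#             ignore_index=-100,                          # collator marks padding with -100
#         )
#         return (loss, logits) if return_outputs else loss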
trainer = CustomTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
)
trainer.train()
trainer.save_model(f"{output_path}/final_model")
Traceback (most recent call last):
  File "/notebooks/main.py", line 118, in <module>
    trainer.train()
  File "/usr/local/lib/python3.9/dist-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.9/dist-packages/transformers/trainer.py", line 1928, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/usr/local/lib/python3.9/dist-packages/accelerate/data_loader.py", line 452, in __iter__
    current_batch = next(dataloader_iter)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/usr/local/lib/python3.9/dist-packages/datasets/arrow_dataset.py", line 2814, in __getitems__
    batch = self.__getitem__(keys)
  File "/usr/local/lib/python3.9/dist-packages/datasets/arrow_dataset.py", line 2810, in __getitem__
    return self._getitem(key)
  File "/usr/local/lib/python3.9/dist-packages/datasets/arrow_dataset.py", line 2794, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 583, in query_table
    _check_valid_index_key(key, size)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 536, in _check_valid_index_key
    _check_valid_index_key(int(max(key)), size=size)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 526, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 4280187 is out of bounds for size 0
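The "out of bounds for size 0" almost certainly means the train split reached the
DataLoader with all of its columns stripped: by default, Trainer prunes every dataset
column whose name is not a parameter of the model's forward() (remove_unused_columns=True),
and MambaFormer is a plain nn.Module whose forward takes a positional tensor rather than
input_ids, so every column gets dropped and the underlying Arrow table ends up with size 0.
A minimal sketch of the usual fix, assuming the other arguments stay as above:

args = TrainingArguments(
    output_dir=output_path,
    remove_unused_columns=False,  # keep input_ids/labels for the custom (non-HF) model
    # ... all other arguments as above ...
)

With the columns preserved, compute_loss still has to unpack the batch itself (see the
commented sketch above), since MambaFormer's forward does not accept Hugging Face-style
keyword arguments.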