I'm trying to reproduce the results from table 2 of your paper.
First, if I understand correctly, there is a bit of a mismatch between your implementation and the TOVA implementation from https://arxiv.org/abs/2401.06104. The current implementation is:
elif mode == 'tova':
eviction_ids = torch.topk(cache_attn_scores[:, :, sink_length:-recent_window], dim=-1, k=stride, largest=False)[1] + sink_length
elif mode == 'tova':
eviction_ids = torch.topk(cache_attn_scores[:, :, :], dim=-1, k=stride, largest=False)[1]
After fixing that, I tried to reproduce the Table 2 results from the paper for 'TOVA', 'H2O' and 'RoCo'.
Here are the results I got:
tova
{'rouge1': 0.31296021280547465, 'rouge2': 0.11589496163640313, 'rougeL': 0.19914965472944104, 'rougeLsum': 0.20350149974062479, 'bleu': 0.042805674847994496}
h2o_head
{'rouge1': 0.3168426756016529, 'rouge2': 0.11983298042885819, 'rougeL': 0.20188195228157899, 'rougeLsum': 0.20641862075377534, 'bleu': 0.045165272142095016}
roco
{'rouge1': 0.3180583914147622, 'rouge2': 0.12269585471287722, 'rougeL': 0.2033885685550363, 'rougeLsum': 0.20759446503031692, 'bleu': 0.0469378196052233}
It seems that my results are far off from what is reported in the paper, and I wonder why there is such a gap.
Below, I attach the full script (based on test_summarization.py file) for reproducing the results:
import warnings
warnings.filterwarnings("ignore")
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer)
from easykv import enable_fixed_kv
from multiprocessing import Pool
from datasets import load_dataset, load_metric
from tqdm import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader
def compute_rouge(predictions, references):
    """Return ROUGE F-measures (upper bound of the bootstrap CI) for the given pairs.

    Uses the `datasets` ROUGE metric; each aggregate score exposes
    low/mid/high estimates, and the `.high.fmeasure` value is reported here.
    """
    rouge = load_metric('rouge')
    raw_scores = rouge.compute(predictions=predictions, references=references)
    return {name: aggregate.high.fmeasure for name, aggregate in raw_scores.items()}
def compute_bleu(predictions, references):
    """Return corpus BLEU for whitespace-tokenized predictions vs. references.

    The `datasets` BLEU metric expects token lists, and each prediction may
    have several references, hence the extra nesting on the reference side.
    """
    bleu = load_metric('bleu')
    tokenized_preds = [pred.split() for pred in predictions]
    tokenized_refs = [[ref.split()] for ref in references]
    result = bleu.compute(predictions=tokenized_preds, references=tokenized_refs)
    return result['bleu']
def main():
    """Reproduce the paper's Table 2: CNN/DailyMail summarization under several
    KV-cache eviction policies ('tova', 'h2o_head', 'roco') with EasyKV.

    Runs greedy decoding with Llama-2-7b-chat at a 0.5 cache budget, gathers
    generations across Accelerate processes, and prints ROUGE/BLEU per policy.
    """
    accelerator = Accelerator()
    path = "meta-llama/Llama-2-7b-chat-hf"
    # Official Llama-2 chat template with the default system prompt.
    template = "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{inst}[/INST]"
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(path)
    dataset = load_dataset("cnn_dailymail", '1.0.0', split='test')
    eval_dataloader = DataLoader(dataset)
    eval_dataloader = accelerator.prepare(eval_dataloader)
    all_results = {}
    accelerator.wait_for_everyone()
    for kv_policy in ['tova', 'h2o_head', 'roco']:
        for budget in [0.5]:
            gens = []
            refs = []
            for sample in tqdm(eval_dataloader, disable=not accelerator.is_main_process):
                # BUG FIX: DataLoader's default collate_fn batches string fields
                # into one-element lists. The original code unwrapped only
                # `highlights` (via highlight[0]) but embedded the *list* for
                # `article` into the prompt, so the model saw the Python list
                # repr "['...']" instead of the article text. Unwrap both.
                article = sample['article'][0]
                highlight = sample['highlights'][0]
                # Define sampling parameters. do_sample=False makes decoding
                # greedy; the tiny temperature/top_p values are effectively
                # inert under greedy decoding but kept for parity with the
                # original script.
                gen_kwargs = dict(
                    max_new_tokens=256,
                    budget=budget,
                    kv_policy=kv_policy,
                    keep_attention=True,
                    do_sample=False,
                    temperature=1e-11,
                    top_p=1e-11,
                    temp_length=0,
                    recent_ratio=0.5,
                )
                prompt = f"Write a SHORT summary of the following text delimited by triple backticks. Return your response which covers the key points of the text.\n```{article}```"
                input_prompt = template.format(inst=prompt)
                input_ids = tokenizer([input_prompt], return_tensors='pt').input_ids.to(model.device)
                # Re-enabled per sample because the encoding stride depends on
                # this prompt's length (capped at 512 tokens).
                enable_fixed_kv(model, tokenizer, mode='encoding', stride=min(input_ids.shape[-1] - 1, 512))
                output = model.easykv_generate(input_ids=input_ids, generation_config=gen_kwargs)
                gens.append(output)
                refs.append(highlight)
            accelerator.wait_for_everyone()
            # Collect generations/references from all processes before scoring.
            gens = accelerator.gather_for_metrics(gens)
            refs = accelerator.gather_for_metrics(refs)
            if accelerator.is_main_process:
                rouge_results = compute_rouge(gens, refs)
                rouge_results['bleu'] = compute_bleu(gens, refs)
                print(kv_policy)
                print(rouge_results)
                all_results[kv_policy] = rouge_results
    # Final summary of every policy's scores, printed once on the main process.
    if accelerator.is_main_process:
        for key in all_results:
            print(key)
            print(all_results[key])
# Standard script entry guard; also required so that multi-process launchers
# (e.g. Accelerate spawning workers) do not re-execute main() on import.
if __name__ == '__main__':
    main()