Hi, I am trying to use LLM-Pruner on the Baichuan-13B model (https://github.com/baichuan-inc/Baichuan-13B). It is also LLaMA-structured, so I expected it to work out of the box, but I ran into some errors. I am still debugging, slowly, and any help or advice would be much appreciated!
Specifically, I ran "CUDA_VISIBLE_DEVICES=0,1 python hf_prune_baichuan.py --base_model models/baichuan-13b-chat --pruning_ratio 0.25 --device cpu --eval_device cuda --block_wise --block_mlp_layer_start 4 --block_mlp_layer_end 30 --block_attention_layer_start 4 --block_attention_layer_end 30 --save_ckpt_log_name baichuan_13b_chat_0.2 --pruner_type taylor --test_after_train --taylor param_first --save_model",
and I got the following output:
Loading checkpoint shards: 100%|██████████| 3/3 [02:14<00:00, 44.74s/it]
2023-07-17 02:29:09 - INFO : Use taylor pruner...
2023-07-17 02:29:09 - INFO : Pruning Attention Layer = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
2023-07-17 02:29:09 - INFO : Pruning MLP Layer = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/dependency.py:362: UserWarning: Unwrapped parameters detected: ['model.layers.10.input_layernorm.weight', 'model.layers.17.post_attention_layernorm.weight', ..., 'model.layers.8.post_attention_layernorm.weight'] (the full list covers every layer's input_layernorm and post_attention_layernorm weight for layers 0-39, plus 'model.norm.weight'; truncated here).
Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.
2023-07-17 02:30:02 - INFO : Start Pruning
2023-07-17 02:30:02 - WARNING : Found cached dataset bookcorpus (/dfs/data/data/bookcorpus/bookcorpus/plain_text/1.0.0/eddee3cae1cc263a431aa98207d4d27fd8a73b0a9742f692af0e6c65afa4d75f)
2023-07-17 02:30:45 - INFO : Start Backwarding in iterative steps = 0...
2023-07-17 02:33:56 - INFO : Loss = 3.644896984100342
Traceback (most recent call last):
  File "hf_prune_baichuan.py", line 299, in <module>
    main(args)
  File "hf_prune_baichuan.py", line 136, in main
    pruner.step()
  File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 179, in step
    for group in self.prune_local():
  File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 238, in prune_local
    imp = self.estimate_importance(group, ch_groups=ch_groups, consecutive_groups=consecutive_groups)
  File "/dfs/data/LLM-Pruner/LLMPruner/torch_pruning/pruner/algorithms/metapruner.py", line 183, in estimate_importance
    return self.importance(group, ch_groups=ch_groups, consecutive_groups=consecutive_groups)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/dfs/data/LLM-Pruner/LLMPruner/pruner/hf_baichuan_pruner.py", line 306, in __call__
    local_norm = local_norm[idxs]
IndexError: index 10240 is out of bounds for dimension 0 with size 5120
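Reading the numbers, my current guess: Baichuan-13B's hidden size is 5120, but W_pack is a single linear whose 15360 output rows stack q, k, and v, so the group's pruning indices run over the packed space while local_norm only has 5120 entries. A quick check of that arithmetic (the segment layout is my assumption from Baichuan's modeling code):

hidden_size = 5120                         # Baichuan-13B hidden dim
# W_pack output rows: [0, 5120) -> q, [5120, 10240) -> k, [10240, 15360) -> v
segment, row = divmod(10240, hidden_size)  # the failing index from the traceback
print(segment, row)                        # 2 0 -> exactly the first row of the v block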
To get this far, I modified "hf_prune_llama.py" and "LLMPruner/pruner/hf_llama_pruner.py" (into the hf_prune_baichuan.py and hf_baichuan_pruner.py seen in the traceback):
1. replacing the model-loading part with:
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained(args.base_model)
2. replacing all "q_proj, k_proj, v_proj" with "W_pack" (which is where I suspect it breaks; see the sketch below).
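Since W_pack stacks the three projections along its output dimension, one direction I am experimenting with in hf_baichuan_pruner.py (around line 306) is to fold the packed row indices back into the per-channel space before indexing. This helper is my own sketch, not anything from the LLM-Pruner API, and it assumes the coupled q/k/v rows of a channel should share one importance entry:

import torch

hidden_size = 5120  # Baichuan-13B config.hidden_size; hypothetical constant for illustration

def fold_packed_idxs(idxs, hidden_size):
    # idxs index W_pack's packed output rows in [0, 3 * hidden_size);
    # map each row back to its channel in [0, hidden_size) so the q, k,
    # and v rows of one channel all hit the same local_norm entry.
    return torch.as_tensor(idxs) % hidden_size

# inside __call__, instead of local_norm[idxs]:
# local_norm = local_norm[fold_packed_idxs(idxs, hidden_size)]

I have not verified this keeps the head groupings consistent, so treat it as a starting point rather than a fix.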
Do you have any advice on a quick fix? Thank you very much!