
Retrieve-Rewrite-Answer: A KG-to-Text Enhanced LLMs Framework for Knowledge Graph Question Answering

Abstract

Despite their competitive performance on knowledge-intensive tasks, large language models (LLMs) still have limitations in memorizing all world knowledge, especially long-tail knowledge. In this paper, we study the KG-augmented language model approach to the knowledge graph question answering (KGQA) task, which requires rich world knowledge. Existing work has shown that retrieving KG knowledge to enhance LLM prompting can significantly improve LLM performance on KGQA. However, these approaches lack a well-formed verbalization of KG knowledge, i.e., they ignore the gap between KG representations and textual representations. To this end, we propose an answer-sensitive KG-to-Text approach that can transform KG knowledge into well-textualized statements that are most informative for KGQA. Based on this approach, we propose a KG-to-Text enhanced LLMs framework for solving the KGQA task. Experiments on several KGQA benchmarks show that the proposed KG-to-Text augmented LLMs approach outperforms previous KG-augmented LLMs approaches in answer accuracy and the usefulness of knowledge statements.

This is the accompanying code & benchmarks for the paper "Retrieve-Rewrite-Answer: A KG-to-Text Enhanced LLMs Framework for Knowledge Graph Question Answering".

UPDATE: The paper has been accepted at the 12th International Joint Conference on Knowledge Graphs (IJCKG 2023).

Requirements

Please install the following dependency libraries.

  • accelerate == 0.21.0
  • bitsandbytes == 0.39.0
  • datasets == 2.16.1
  • deepspeed == 0.10.0
  • langchain == 0.0.247
  • numpy == 1.24.4
  • pandas == 2.0.3
  • peft == 0.5.0
  • simpletransformers == 0.64.3
  • torch == 2.0.1
  • tqdm == 4.65.0
  • transformers == 4.32.0
  • Python version == 3.8.17

Package Description

Important files/folders are described as follows:

Retrieve-Rewrite-Answer/main/
├─ corpus_generation/: KG-to-Text corpus generation
    ├─ MetaQA: Corpus generation for MetaQA
        ├─ data: Original MetaQA QA dataset
        ├─ indexes: KB, type annotation, and dict
        ├─ process: Process the original MetaQA QA dataset
           ├─ tripledict.py: Step 1. Generate dict files
           ├─ process.py: Step 2. Generate gold relation path and gold subgraph
        ├─ KG-to-Text: Generate KG-to-Text corpus based on ChatGPT
           ├─ sample.py: Step 1. Sample some data for corpus generation
           ├─ corpus_generation.py: Step 2. Generate corpus
           ├─ format.py: Step 3. Transform the generated corpus into Stanford Alpaca format
           ├─ data: MetaQA KG-to-Text corpus
    ├─ WQSP: Corpus generation for WQSP
        ├─ data: Original WQSP QA dataset
        ├─ indexes: KB and dict
        ├─ process: Process the original WQSP QA dataset
           ├─ dict.py: Step 1. Generate dict files
           ├─ get_graph.py: Step 2. Generate gold subgraph
           ├─ en_re.py: Step 3. Generate relation file and name dict
           ├─ convert_name.py: Step 4. Convert MIDs to entity names in the gold subgraph
           ├─ query_interface.py: Query entity names from Freebase
        ├─ KG-to-Text: Generate KG-to-Text corpus based on ChatGPT
           ├─ corpus_generation.py: Step 1. Generate corpus
           ├─ format.py: Step 2. Transform the generated corpus into Stanford Alpaca format
           ├─ data: WQSP KG-to-Text corpus
    ├─ ZJQA: Corpus generation for ZJQA
        ├─ data: Original ZJQA QA dataset
        ├─ indexes: KB and dict
        ├─ process: Process the original ZJQA QA dataset
           ├─ process.py: Process the QA dataset
        ├─ KG-to-Text: Generate KG-to-Text corpus based on ChatGPT
           ├─ corpus_generation.py: Step 1. Generate corpus
           ├─ format.py: Step 2. Transform the generated corpus into Stanford Alpaca format
           ├─ data: ZJQA KG-to-Text corpus
├─ finetune-llama/: Finetune llama on the generated KG-to-Text corpus
    ├─ run_sft_chat-7b.sh: Run this shell to finetune llama-7b 
    ├─ run_sft_chat-13b.sh: Run this shell to finetune llama-13b 
    ├─ run_sft_chat_chinese-7b.sh: Run this shell to finetune Chinese-Alpaca-7b 
    ├─ run_sft_chat_chinese-13b.sh: Run this shell to finetune Chinese-Alpaca-13b 
    ├─ run_clm_sft_with_peft-7b.py: LoRA for llama-7b
    ├─ run_clm_sft_with_peft-13b.py: LoRA for llama-13b
    ├─ run_clm_sft_with_peft-chinese-7b.py: LoRA for Chinese-Alpaca-7b
    ├─ run_clm_sft_with_peft-chinese-13b.py: LoRA for Chinese-Alpaca-13b
    ├─ MetaQA: MetaQA KG-to-Text corpus in Stanford Alpaca format
    ├─ WQSP: WQSP KG-to-Text corpus in Stanford Alpaca format
    ├─ ZJQA: ZJQA KG-to-Text corpus in Stanford Alpaca format
├─ finetune-t5/: Finetune flan-t5 on the generated KG-to-Text corpus
    ├─ train.py: Run this file to finetune flan-t5 
    ├─ data: MetaQA KG-to-Text corpus in Stanford Alpaca format
├─ KGQA-MetaQA/: KGQA on MetaQA
    ├─ retrieve: Subgraph Retrieval
        ├─ train.py: Step 1. Train path prediction
        ├─ predict.py: Step 2. Predict relation path
        ├─ retrieve.py: Step 3. Triple sampling
    ├─ rewrite: KG-to-Text
        ├─ infer_llama.py: KG-to-Text based on llama
        ├─ infer_mvp.py: KG-to-Text based on mvp/mtl
        ├─ infer_t5-xl.py: KG-to-Text based on flan-t5
    ├─ answer: Knowledge Text Enhanced Reasoning
        ├─ answer_gpt_no.py: Answer question with no knowledge based on ChatGPT
        ├─ answer_gpt_text.py: Answer question with free-form text based on ChatGPT
        ├─ answer_gpt_triple.py: Answer question with triple-form text based on ChatGPT
        ├─ answer_llama_no.py: Answer question with no knowledge based on llama
        ├─ answer_llama_text.py: Answer question with free-form text based on llama
        ├─ answer_llama_triple.py: Answer question with triple-form text based on llama
├─ KGQA-WQSP/: KGQA on WQSP
    ├─ retrieve: Subgraph Retrieval
        ├─ hop_prediction: Hop prediction
            ├─ train.py: Train and predict hop number of the question
        ├─ path_prediction: Relation path prediction
            ├─ train.py: Step 1. Train relation path prediction
            ├─ predict.py: Step 2. Predict relation path
            ├─ retrieve.py: Step 3. Triple sampling
    ├─ rewrite: KG-to-Text
        ├─ infer_llama.py: KG-to-Text based on llama
    ├─ answer: Knowledge Text Enhanced Reasoning
        ├─ answer_t0.py: Answer question based on t0
        ├─ answer_t5.py: Answer question based on t5
├─ KGQA-ZJQA/: KGQA on ZJQA
    ├─ retrieve: Subgraph Retrieval
        ├─ hop_prediction: Hop prediction
            ├─ train.py: Train and predict hop number of the question
        ├─ path_prediction: Relation path prediction
            ├─ train.py: Step 1. Train relation path prediction
            ├─ predict.py: Step 2. Predict relation path
            ├─ retrieve.py: Step 3. Triple sampling
    ├─ rewrite: KG-to-Text
        ├─ infer_llama.py: KG-to-Text based on llama
    ├─ answer: Knowledge Text Enhanced Reasoning
        ├─ answer_gpt_no.py: Answer question with no knowledge based on ChatGPT
        ├─ answer_gpt_text.py: Answer question with free-form text based on ChatGPT
        ├─ answer_gpt_triple.py: Answer question with triple-form text based on ChatGPT
        ├─ answer_llama_no.py: Answer question with no knowledge based on llama
        ├─ answer_llama_text.py: Answer question with free-form text based on llama
        ├─ answer_llama_triple.py: Answer question with triple-form text based on llama

Resources

Processed data

Download the indexes for WQSP to corpus_generation/WQSP: https://pan.baidu.com/s/19qDw3wfYq7nUf3MWjOln8g?pwd=l94c

We provide the processed datasets:

  • Download the processed data to corpus_generation/MetaQA and rename it to "processed_data": https://pan.baidu.com/s/1B7RA8uFx972TTuwj79hu7g?pwd=pv2v
  • Download the processed data to corpus_generation/WQSP and rename it to "processed_data": https://pan.baidu.com/s/1Dp0pSy-AdEhrb6bqJOQLRQ?pwd=iaqv
  • Download the processed data to corpus_generation/ZJQA and rename it to "processed_data": https://pan.baidu.com/s/1TBIU3kVUGsuYQ2D7wXsjRw?pwd=77n3

We also provide the retrieved results:

  • For WQSP, download the retrieved result to KGQA-WQSP/retrieve/path_prediction: https://pan.baidu.com/s/1vORsf1X6RhgkXjXO1vTfjw?pwd=21vd
  • For MetaQA, download the retrieved result to KGQA-MetaQA/retrieve/ and rename it to "result": https://pan.baidu.com/s/1z-LYruiBzAx9P0p6IVsKuQ?pwd=ebut

The KGQA dataset for ZJQA is in corpus_generation/ZJQA/data, and the corresponding KG is in corpus_generation/ZJQA/indexes.

LoRA checkpoint for KG-to-Text

Download llama LoRA checkpoint for KG-to-Text to finetune-llama: https://pan.baidu.com/s/1IdV7Fs4o12zwnjK49x1CcQ?pwd=ylwq
Download Flan-T5-xl LoRA checkpoint for KG-to-Text to finetune-flan-t5-xl: https://pan.baidu.com/s/1Ou1G2RwNK-bQ_k8EYMgX2w?pwd=z2sz

LLM

Download the following LLMs to the pretrain directory: Llama-2-7b-chat, Llama-2-13b-chat, Chinese-Alpaca-2-7B, Chinese-Alpaca-2-13B, Flan-T5-small, Flan-T5-xl, Flan-T5-xxl, T5-large-lm-adapt, T5-xl-lm-adapt, T5-xxl-lm-adapt, T0, T0-3B, bert-base-uncased, bert-base-chinese

Usage

Corpus Generation

We provide our generated KG-to-Text corpora in corpus_generation/MetaQA/KG-to-Text/data, corpus_generation/WQSP/KG-to-Text/data, and corpus_generation/ZJQA/KG-to-Text/data. We also provide these corpora in Stanford Alpaca format for direct finetuning in finetune-llama/MetaQA, finetune-llama/WQSP, finetune-llama/ZJQA, and finetune-t5/data.
If you want to generate your own KG-to-Text corpus, follow these steps and be ready to spend a lot of money ;)

  1. Run the files in process sequentially, as described in Package Description, to process the QA data.
    For WQSP, you need to deploy Freebase in Virtuoso to support entity name queries. You can use our provided processed data to skip this step.
  2. Run the files in KG-to-Text sequentially, as described in Package Description, to generate the KG-to-Text corpus and transform it into Stanford Alpaca format; a minimal sketch of the generation call follows this list.
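For reference, the generation step calls the ChatGPT API to verbalize triples. The sketch below is a minimal illustration assuming the pre-1.0 openai client; the prompt wording, model name, and sample triple are assumptions, not the exact ones used in corpus_generation.py.

# Hedged sketch of the ChatGPT-based KG-to-Text generation step.
# Prompt wording, model name, and sample data are illustrative assumptions.
import openai

openai.api_key = "YOUR_API_KEY"  # assumption: supply your own key

def triples_to_text(triples):
    """Ask ChatGPT to verbalize (head, relation, tail) triples into prose."""
    triple_str = "\n".join(f"({h}, {r}, {t})" for h, r, t in triples)
    prompt = ("Transform the following knowledge graph triples into a "
              "fluent natural-language statement:\n" + triple_str)
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
    )
    return response["choices"][0]["message"]["content"]

print(triples_to_text([("Inception", "directed_by", "Christopher Nolan")]))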

LLM finetuning

Llama finetuning

Run the appropriate shell script in finetune-llama.
Note that different shell scripts finetune Llama models of different sizes and languages, so choose accordingly. You may need to modify some parameters (e.g., pretrained_model, batch_size). A sketch of the underlying LoRA setup is shown below.
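The run_clm_sft_with_peft-*.py scripts apply LoRA via peft. The following is a minimal sketch of that pattern, not the scripts' exact configuration; the rank, alpha, dropout, target modules, and model path are illustrative assumptions.

# Hedged sketch of wrapping a Llama model with LoRA adapters using peft.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model

base = "pretrain/Llama-2-7b-chat"  # assumption: your local model path
tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.float16)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                # assumed rank
    lora_alpha=32,      # assumed scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed attention projections
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable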

Flan-T5 finetuning

Run train.py in finetune-t5.
You may need to modify some parameters (e.g., model_path, batch_size). A hedged sketch of such a seq2seq finetuning run follows.
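The sketch below illustrates seq2seq finetuning on an Alpaca-format JSON corpus; the field names (instruction/input/output), hyperparameters, and paths are assumptions, not necessarily what train.py uses.

# Hedged sketch of Flan-T5 finetuning on a KG-to-Text corpus.
from datasets import load_dataset
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForSeq2Seq, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)

model_path = "pretrain/flan-t5-xl"  # assumption: your local model path
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

data = load_dataset("json", data_files={"train": "data/train.json"})  # assumed path

def preprocess(example):
    # assumption: Stanford Alpaca fields "instruction", "input", "output"
    inputs = tokenizer(example["instruction"] + "\n" + example["input"],
                       truncation=True, max_length=512)
    labels = tokenizer(text_target=example["output"], truncation=True, max_length=256)
    inputs["labels"] = labels["input_ids"]
    return inputs

train_set = data["train"].map(preprocess, remove_columns=data["train"].column_names)

args = Seq2SeqTrainingArguments(output_dir="out", per_device_train_batch_size=4,
                                num_train_epochs=3, learning_rate=1e-4)
trainer = Seq2SeqTrainer(model=model, args=args, train_dataset=train_set,
                         data_collator=DataCollatorForSeq2Seq(tokenizer, model=model))
trainer.train()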

KGQA

Retrieve

Run the files in retrieve sequentially, as described in Package Description, to retrieve the subgraph; sketches of the triple-sampling step and the Freebase name lookup follow this note.
For WQSP, you need to deploy Freebase in Virtuoso to support entity name queries. You can use our provided retrieved results to skip this step.
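Conceptually, triple sampling walks the predicted relation path outward from the topic entity and collects the triples it traverses. The sketch below assumes a (head, relation) -> tails dictionary (the repository ships pickled indexes such as taildict.pkl); the data structure and function name are illustrative assumptions.

# Hedged sketch of triple sampling along a predicted relation path.
import pickle

# assumption: maps (head, relation) -> list of tail entities
taildict = pickle.load(open("indexes/taildict.pkl", "rb"))

def sample_subgraph(topic_entity, relation_path):
    """Collect the triples reached by following relation_path from topic_entity."""
    frontier = {topic_entity}
    triples = []
    for relation in relation_path:
        next_frontier = set()
        for head in frontier:
            for tail in taildict.get((head, relation), []):
                triples.append((head, relation, tail))
                next_frontier.add(tail)
        frontier = next_frontier
    return triples

For WQSP, entity-name lookup against a locally hosted Freebase can be done with a SPARQL query against Virtuoso. This sketch assumes the SPARQLWrapper package and the default Virtuoso endpoint; query_interface.py may implement this differently.

# Hedged sketch of resolving a Freebase MID to its English name via Virtuoso.
from SPARQLWrapper import SPARQLWrapper, JSON

sparql = SPARQLWrapper("http://localhost:8890/sparql")  # assumed Virtuoso endpoint

def mid_to_name(mid):
    """Look up the English name of a Freebase MID, e.g. 'm.0d05w3'."""
    sparql.setQuery(f"""
        PREFIX ns: <http://rdf.freebase.com/ns/>
        SELECT ?name WHERE {{
            ns:{mid} ns:type.object.name ?name .
            FILTER (lang(?name) = "en")
        }}""")
    sparql.setReturnFormat(JSON)
    bindings = sparql.query().convert()["results"]["bindings"]
    return bindings[0]["name"]["value"] if bindings else None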

Rewrite

Run infer_llama.py, infer_t5-xl.py, or infer_mvp.py in rewrite to transform triple-form text into free-form text with the corresponding LLM.
You may need to modify the model or output file paths. A sketch of LoRA-based inference is shown below.
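In the spirit of infer_llama.py, the sketch below loads a base Llama model, attaches the KG-to-Text LoRA checkpoint with peft, and generates free-form text; the paths and prompt template are assumptions.

# Hedged sketch of KG-to-Text inference with a LoRA-finetuned Llama.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_path = "pretrain/Llama-2-7b-chat"      # assumed base model path
lora_path = "finetune-llama/lora-kg2text"   # assumed LoRA checkpoint path

tokenizer = AutoTokenizer.from_pretrained(base_path)
model = AutoModelForCausalLM.from_pretrained(base_path, torch_dtype=torch.float16,
                                             device_map="auto")
model = PeftModel.from_pretrained(model, lora_path)
model.eval()

triples = "(Inception, directed_by, Christopher Nolan)"
prompt = f"Transform the triples into natural language:\n{triples}\nText:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:],
                       skip_special_tokens=True))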

Answer

Run the files in answer to answer the questions; for detailed usage, refer to Package Description. You may need to modify the model, input file, or output file paths. The sketch below illustrates how the knowledge text is prepended to the question.
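The no/triple/text variants above differ only in what knowledge, if any, is placed before the question in the prompt. A minimal illustrative prompt builder (the exact wording used in the scripts may differ):

# Hedged sketch of prompt construction for knowledge-enhanced answering.
def build_prompt(question, knowledge=None):
    """Prepend optional knowledge statements (free-form or triple-form) to a question."""
    if knowledge:
        return (f"Below are facts that may be relevant:\n{knowledge}\n"
                f"Question: {question}\nAnswer:")
    return f"Question: {question}\nAnswer:"  # no-knowledge baseline

print(build_prompt("Who directed Inception?",
                   "Inception was directed by Christopher Nolan."))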

Contact

Please consider creating a new issue. We will respond to your questions within a few days.

BibTeX

If you find this work helpful for your research, please cite:

@article{wu2023retrieve,
  title={Retrieve-Rewrite-Answer: A KG-to-Text Enhanced LLMs Framework for Knowledge Graph Question Answering},
  author={Wu, Yike and Hu, Nan and Bi, Sheng and Qi, Guilin and Ren, Jie and Xie, Anhuan and Song, Wei},
  year={2023}
}


Issues

Error when finetuning llama

Below is the full error output. Before running the code, the only thing I changed was os.environ["CUDA_VISIBLE_DEVICES"] = "0". Is single-GPU finetuning not supported?

Traceback (most recent call last):
  File "/public2/home/wangchen/Retrieve-Rewrite-Answer-main/finetune-llama/run_clm_sft_with_peft-7b.py", line 419, in <module>
    main()
  File "/public2/home/wangchen/Retrieve-Rewrite-Answer-main/finetune-llama/run_clm_sft_with_peft-7b.py", line 392, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1904, in _inner_training_loop
    self.accelerator.clip_grad_norm_(
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/accelerate/accelerator.py", line 1893, in clip_grad_norm_
    self.unscale_gradients()
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/accelerate/accelerator.py", line 1856, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 307, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 229, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.

Also, when training normally on two GPUs without modifying the code, saving seems to fail once progress reaches epoch=1.0:

[INFO|trainer.py:3160] 2023-12-03 12:44:05,389 >>   Num examples = 1386
[INFO|trainer.py:3163] 2023-12-03 12:44:05,389 >>   Batch size = 4
{'eval_loss': 0.13147063553333282, 'eval_runtime': 72.5555, 'eval_samples_per_second': 19.103, 'eval_steps_per_second': 2.398, 'epoch': 1.0}
Traceback (most recent call last):
  File "/public2/home/wangchen/Retrieve-Rewrite-Answer-main/finetune-llama/run_clm_sft_with_peft-7b.py", line 419, in <module>
    main()
  File "/public2/home/wangchen/Retrieve-Rewrite-Answer-main/finetune-llama/run_clm_sft_with_peft-7b.py", line 392, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1937, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2282, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2350, in _save_checkpoint
    self.save_model(output_dir, _internal_call=True)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2826, in save_model
    raise ValueError("Install Accelerate from main branch")
ValueError: Install Accelerate from main branch

Bug in dataset caching when finetuning llama

The error message is as follows:
Traceback (most recent call last):
  File "D:\pyProjects2\Retrieve-Rewrite-Answer\finetune-llama\build_dataset.py", line 66, in build_instruction_dataset
    processed_dataset = datasets.load_from_disk(cache_path)
  File "D:\Software\anaconda3\envs\KG\lib\site-packages\datasets\load.py", line 2232, in load_from_disk
    raise FileNotFoundError(
FileNotFoundError: Directory ZJQA\train\train is neither a Dataset directory nor a DatasetDict directory.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "D:\pyProjects2\Retrieve-Rewrite-Answer\finetune-llama\run_clm_sft_with_peft-chinese-7b.py", line 419, in <module>
    main()
  File "D:\pyProjects2\Retrieve-Rewrite-Answer\finetune-llama\run_clm_sft_with_peft-chinese-7b.py", line 298, in main
    train_dataset = build_instruction_dataset(
  File "D:\pyProjects2\Retrieve-Rewrite-Answer\finetune-llama\build_dataset.py", line 69, in build_instruction_dataset
    raw_dataset = load_dataset("json", data_files=file, cache_dir=cache_path)
  File "D:\Software\anaconda3\envs\KG\lib\site-packages\datasets\load.py", line 2109, in load_dataset
    builder_instance = load_dataset_builder(
  File "D:\Software\anaconda3\envs\KG\lib\site-packages\datasets\load.py", line 1795, in load_dataset_builder
    dataset_module = dataset_module_factory(
  File "D:\Software\anaconda3\envs\KG\lib\site-packages\datasets\load.py", line 1404, in dataset_module_factory
    return PackagedDatasetModuleFactory(
  File "D:\Software\anaconda3\envs\KG\lib\site-packages\datasets\load.py", line 947, in get_module
    data_files = DataFilesDict.from_patterns(
  File "D:\Software\anaconda3\envs\KG\lib\site-packages\datasets\data_files.py", line 671, in from_patterns
    DataFilesList.from_patterns(
  File "D:\Software\anaconda3\envs\KG\lib\site-packages\datasets\data_files.py", line 577, in from_patterns
    resolve_pattern(
  File "D:\Software\anaconda3\envs\KG\lib\site-packages\datasets\data_files.py", line 335, in resolve_pattern
    protocol_prefix = fs.protocol + "://" if fs.protocol != "file" else ""
TypeError: can only concatenate tuple (not "str") to tuple

Is the training procedure as follows?

Taking MetaQA as an example, here is my understanding of the training pipeline. Is this how training is done?

  1. corpus_generation/MetaQA/process/process.py
  2. corpus_generation/MetaQA/process/tripledict.py

  3. corpus_generation/MetaQA/KG-to-Text/corpus_generation.py
  4. corpus_generation/MetaQA/KG-to-Text/format.py
  5. corpus_generation/MetaQA/KG-to-Text/sample.py

  6. finetune llama

  7. KGQA-MetaQA/retrieve/train.py
  8. KGQA-MetaQA/retrieve/retrieve.py
  9. KGQA-MetaQA/retrieve/predict.py

  10. KGQA-MetaQA/rewrite/infer_llama.py
  11. KGQA-MetaQA/rewrite/infer_t5-xl.py
  12. KGQA-MetaQA/rewrite/infer_mvp.py

  13. KGQA-MetaQA/answer/answer_llama_no.py
  14. KGQA-MetaQA/answer/answer_llama_triple.py
  15. KGQA-MetaQA/answer/answer_llama_text.py

Llama finetuning issue

The base model is Chinese-Alpaca-2-7B, finetuned on a 4090. Even after changing batch_size to 1 I still get a CUDA out of memory error. How much memory is needed at minimum?

Where is this file?

Traceback (most recent call last):
  File "retrieve.py", line 5, in <module>
    taildict=pickle.load(open('../../../corpus_generation/ZJQA/indexes/taildict.pkl','rb'))
FileNotFoundError: [Errno 2] No such file or directory: '../../../corpus_generation/ZJQA/indexes/taildict.pkl'

Dataset questions

Hello! Thank you very much for the new ideas provided in the paper and code. I am currently building my own disease knowledge graph and hope to use this code for the KG-to-Text step. I have some questions about the dataset; I am constructing my own data following the ZJQA dataset:
1. indexes/triple.txt: the triples.
2. KG-to-Text/data: is this generated with corpus_generation.py?
3. data/: (question, head entity, relation, tail entity, a number). What does this number mean? Is it the hop count (one hop or two hops)?
For example, in "谁参加了第18届亚运会乒乓球比赛混双决赛 第18届亚运会 项目|成员 孙颖莎 2", what does the 2 mean? How should I best construct this dataset?
Looking forward to your guidance!

Question about the generated training data

(screenshots of the question omitted)

While inspecting the training data, I found that over 6000 generated samples exist for this one question, yet the question corresponds to only a single subgraph. Is something wrong with the generated data? Should I delete the redundant training samples for this question?

When running run_clm_sft_with_peft-7b.py, a missing-file error is reported: FileNotFoundError: [Errno 2] No such file or directory: _dir_dev_json_default-2995dd276f7112c1_0.0.0_8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96.lock'

01/21/2024 15:27:00 - INFO - __main__ - Training files: D:\jjx\Retrieve-Rewrite-Answer-main\Retrieve-Rewrite-Answer-main\finetune-llama\MetaQA\dataset_dir\dev.json D:\jjx\Retrieve-Rewrite-Answer-main\Retrieve-Rewrite-Answer-main\finetune-llama\MetaQA\dataset_dir\train.json
01/21/2024 15:27:00 - WARNING - root - building dataset...
Using custom data configuration default-2995dd276f7112c1
Loading Dataset Infos from C:\Users\jiaojiaxing.conda\envs\RRA\lib\site-packages\datasets\packaged_modules\json
01/21/2024 15:27:11 - INFO - datasets.builder - Using custom data configuration default-2995dd276f7112c1
01/21/2024 15:27:11 - INFO - datasets.info - Loading Dataset Infos from C:\Users\jiaojiaxing.conda\envs\RRA\lib\site-packages\datasets\packaged_modules\json
Traceback (most recent call last):
  File "D:\jjx\Retrieve-Rewrite-Answer-main\Retrieve-Rewrite-Answer-main\finetune-llama\run_clm_sft_with_peft-7b.py", line 429, in <module>
    main()
  File "D:\jjx\Retrieve-Rewrite-Answer-main\Retrieve-Rewrite-Answer-main\finetune-llama\run_clm_sft_with_peft-7b.py", line 307, in main
    train_dataset = build_instruction_dataset(
  File "D:\jjx\Retrieve-Rewrite-Answer-main\Retrieve-Rewrite-Answer-main\finetune-llama\build_dataset.py", line 65, in build_instruction_dataset
    raw_dataset = load_dataset("json", data_files=file, cache_dir=cache_path)
  File "C:\Users\jiaojiaxing.conda\envs\RRA\lib\site-packages\datasets\load.py", line 2523, in load_dataset
    builder_instance = load_dataset_builder(
  File "C:\Users\jiaojiaxing.conda\envs\RRA\lib\site-packages\datasets\load.py", line 2232, in load_dataset_builder
    builder_instance: DatasetBuilder = builder_cls(
  File "C:\Users\jiaojiaxing.conda\envs\RRA\lib\site-packages\datasets\builder.py", line 418, in __init__
    with FileLock(lock_path):
  File "C:\Users\jiaojiaxing.conda\envs\RRA\lib\site-packages\filelock\_api.py", line 297, in __enter__
    self.acquire()
  File "C:\Users\jiaojiaxing.conda\envs\RRA\lib\site-packages\filelock\_api.py", line 255, in acquire
    self._acquire()
  File "C:\Users\jiaojiaxing.conda\envs\RRA\lib\site-packages\filelock\_windows.py", line 28, in _acquire
    fd = os.open(self.lock_file, flags, self._context.mode)
FileNotFoundError: [Errno 2] No such file or directory: 'D:\jjx\Retrieve-Rewrite-Answer-main\Retrieve-Rewrite-Answer-main\finetune-llama\MetaQA\dataset_dir\dev\_jjx_Retrieve-Rewrite-Answer-main_Retrieve-Rewrite-Answer-main_finetune-llama_MetaQA_dataset_dir_dev_json_default-2995dd276f7112c1_0.0.0_8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96.lock'

Process finished with exit code 1

Running without --do_train and --do_eval raises no error and the script runs normally, but no training takes place.

Error during rewrite

An error occurs when running infer_llama.
This problem did not seem to occur with the chatglm model I finetuned myself, although its results were not good.
After switching to the original llama, the failure here seems to happen when unpacking the txt...

Vocab of the base model: 32000
Vocab of the tokenizer: 32000
loading peft model
0it [00:00, ?it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
0it [00:00, ?it/s]
Traceback (most recent call last):
  File "/public2/home/wangchen/Retrieve-Rewrite-Answer-main/KGQA-MetaQA/rewrite/infer_llama.py", line 75, in <module>
    generation_output = model.generate(
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/peft/peft_model.py", line 977, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1719, in generate
    return self.sample(
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2801, in sample
    outputs = self(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1034, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 922, in forward
    layer_outputs = decoder_layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/public2/home/wangchen/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 672, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/public2/home/wangchen/Retrieve-Rewrite-Answer-main/KGQA-MetaQA/rewrite/attn_and_long_ctx_patches.py", line 53, in xformers_forward
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/public2/home/wangchen/Retrieve-Rewrite-Answer-main/KGQA-MetaQA/rewrite/attn_and_long_ctx_patches.py", line 172, in adaptive_ntk_forward
    self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
IndexError: too many indices for tensor of dimension 2
