Hi, I hit an out-of-memory (OOM) error when running the example in this repository. What is the minimum GPU memory required to run it? Full log below:
WARNING:root:The `device_map` argument is not provided. We will override the device_map argument. to set the entire model on the current device. If you want to set the model on multiple devices, please provide a custom `device_map` argument.
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00, 1.56s/it]
/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py:257: UserWarning: No dataset is provided. Make sure to set config.batch_size to the correct value before training.
warnings.warn(
0%| | 0/5000 [00:00<?, ?it/s]/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
0%| | 5/5000 [00:11<3:06:40, 2.24s/it]
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/haosdent/k8s-rl/ppo_by_llm/ppo_train.py", line 110, in <module>
train_stats = agent.terminate_episode()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/llamagym/agent.py", line 140, in terminate_episode
train_stats = self.train_batch(
^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/llamagym/agent.py", line 165, in train_batch
train_stats = self.ppo_trainer.step(queries, responses, rewards)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py", line 788, in step
logprobs, logits, vpreds, _ = self.batched_forward_pass(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py", line 984, in batched_forward_pass
logits, _, values = model(**input_kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/trl/models/modeling_value_head.py", line 170, in forward
base_model_output = self.pretrained_model(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/peft/peft_model.py", line 1073, in forward
return self.base_model(
^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 103, in forward
return self.model.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1176, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1019, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 755, in forward
hidden_states = self.mlp(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 241, in forward
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 414, in forward
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py", line 563, in matmul
return MatMul8bitLt.apply(A, B, out, bias, state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py", line 404, in forward
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/haosdent/miniconda3/envs/localGPT/lib/python3.11/site-packages/bitsandbytes/functional.py", line 1816, in mm_dequant
out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 30.00 MiB. GPU 0 has a total capacity of 23.50 GiB of which 9.69 MiB is free. Including non-PyTorch memory, this process has 23.47 GiB memory in use. Of the allocated memory 23.01 GiB is allocated by PyTorch, and 195.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
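
To make sense of the numbers in that last message, here is a rough back-of-the-envelope breakdown. It is only a sketch: the variable and function names are my own, the input values are copied from the error text above, and everything else is simple arithmetic on them. It does not query the GPU.

```python
# Values copied from the CUDA OOM message above (units as printed there).
MIB_PER_GIB = 1024

total_gib = 23.50               # "total capacity"
free_mib = 9.69                 # "is free"
process_gib = 23.47             # "this process has ... memory in use"
allocated_gib = 23.01           # "allocated by PyTorch"
reserved_unalloc_mib = 195.31   # "reserved by PyTorch but unallocated"
requested_mib = 30.00           # "Tried to allocate"


def summarize_oom():
    """Derive a few quantities the message does not print directly."""
    reserved_unalloc_gib = reserved_unalloc_mib / MIB_PER_GIB
    # Everything PyTorch's caching allocator holds: live tensors plus cache.
    reserved_total_gib = allocated_gib + reserved_unalloc_gib
    # What is left of the process footprint (CUDA context, library buffers, ...).
    non_pytorch_gib = process_gib - reserved_total_gib
    return {
        "reserved_total_gib": reserved_total_gib,
        "non_pytorch_gib": non_pytorch_gib,
        # Share of PyTorch's reserved memory that is cached but unused.
        "cached_share": reserved_unalloc_gib / reserved_total_gib,
        # Share of the whole card taken by live PyTorch tensors.
        "allocated_share_of_gpu": allocated_gib / total_gib,
    }


if __name__ == "__main__":
    for name, value in summarize_oom().items():
        print(f"{name}: {value:.4f}")
```

If this arithmetic is right, less than 1% of what PyTorch has reserved is unused cache, and live tensors take up almost the entire 23.50 GiB card. So the `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` hint could recover at most about 195 MiB, which is probably not enough to get through training. The card looks genuinely full rather than fragmented, hence the question about the minimum memory needed.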