Comments (6)

rasbt commented on August 20, 2024

Thanks for reporting! And I could swear I wrote a response to this earlier ... not sure what happened to it. In any case, I think this could be a PyTorch bug.

If you change it to

checkpoint = torch.load("model_and_optimizer.pth")
model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)   # NEW

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train();

could you check whether it solves it for you? It did the trick for me.
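
A related sketch (assuming the chapter's GPTModel, GPT_CONFIG_124M, and device variables are defined as in the book code): torch.load also accepts a map_location argument, which places the checkpoint tensors on the target device at load time and may sidestep the same device mismatch:

# Alternative sketch: load the checkpoint directly onto the target device
checkpoint = torch.load("model_and_optimizer.pth", map_location=device)

model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train();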

d-kleine commented on August 20, 2024

Just tested; I have the same issue (see ch05/01_main-chapter-code/exercise-solutions.ipynb).

Since the model will already have been moved to the GPU (when a CUDA-capable one is available), you could also move the optimizer's state tensors to the GPU:

# Iterate over all states in the optimizer
for state in optimizer.state.values():
    # Iterate over all key-value pairs in the current state
    for param_name, param_value in state.items():
        # Check if the value is a torch.Tensor
        if isinstance(param_value, torch.Tensor):
            # Move the tensor to the specified device (e.g., GPU or CPU)
            state[param_name] = param_value.to(device)
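
If this comes up repeatedly, the same loop can be wrapped in a small helper (the name optimizer_to below is just a suggestion, not from the notebook; torch, model, optimizer, and device are assumed to be defined as above), followed by a quick sanity check that the model parameters and optimizer state now share one device:

def optimizer_to(optimizer, device):
    # Move every tensor in the optimizer's per-parameter state
    # (e.g., AdamW's exp_avg and exp_avg_sq buffers) to the given device
    for state in optimizer.state.values():
        for param_name, param_value in state.items():
            if isinstance(param_value, torch.Tensor):
                state[param_name] = param_value.to(device)

optimizer_to(optimizer, device)

# Sanity check: collect the devices of all parameters and optimizer state tensors
devices = {p.device for p in model.parameters()}
devices |= {v.device for s in optimizer.state.values()
            for v in s.values() if isinstance(v, torch.Tensor)}
assert len(devices) == 1, f"Tensors live on multiple devices: {devices}"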

d-kleine commented on August 20, 2024

At least for me, it worked. And I agree, seems like a bug.

Two notes on the exercises:

  • There is still a duplicated model.to(device)
  • The GPT2-XL model might be too heavy; it used all of my RAM (32 GB) and froze my computer for a few seconds.

frankchieng commented on August 20, 2024

tbh, I recently pretrained a GPT-2/GPT-3-style model from scratch in another implementation and compared different GPU clusters. 24 GB of VRAM, like an RTX 4090 or A10G, goes OOM once I add the HellaSwag evaluation part. IMHO, based on my tests, if you train on the HuggingFaceFW/fineweb-edu dataset from HF (the minimum edu_fineweb10B subset), you need at least around 48 GB of VRAM per GPU, like an 8x L20 setup, but that is slower than 8x A100, at roughly half the FLOPS. Anyway, nowadays I think 8x 80 GB A100 is a standard requirement for LLM pretraining.

rasbt commented on August 20, 2024

There is still a duplicated model.to(device)

Thanks for the note, I'll clean that up. But it shouldn't hurt: PyTorch is smart enough to detect when the model is already on the target device and skip the op.

The GPT2-XL model might be too heavy; it used all of my RAM (32 GB) and froze my computer for a few seconds.

Yeah. Pretraining can be expensive. We trained TinyLlama on 64xA100s 😅. Besides lowering context length and batch size, you can try to train in lower precision if your GPU supports it:

model.to(device=device, dtype=torch.bfloat16)
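
For a rough sense of why GPT-2 XL is heavy (back-of-envelope numbers, not measurements from this thread, assuming ~1.56B parameters): the weights alone are about 6 GB in float32 and about 3 GB in bfloat16, and during training AdamW's states and the gradients add several more copies before activations are even counted:

# Back-of-envelope memory estimate (assumes ~1.56B parameters for GPT-2 XL)
n_params = 1.56e9

weights_fp32_gb = n_params * 4 / 1e9   # ~6.2 GB for float32 weights
weights_bf16_gb = n_params * 2 / 1e9   # ~3.1 GB for bfloat16 weights

# Training adds AdamW's two state tensors plus gradients per parameter
adamw_state_gb = n_params * 2 * 4 / 1e9  # ~12.5 GB (float32 states)
grads_fp32_gb  = n_params * 4 / 1e9      # ~6.2 GB

print(f"weights: {weights_fp32_gb:.1f} GB fp32 / {weights_bf16_gb:.1f} GB bf16")
print(f"training extras: ~{adamw_state_gb + grads_fp32_gb:.1f} GB before activations")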

I can also add some FSDP code sometime in the future.

@frankchieng

Btw does the added model.to(device) call fix things for you too?

rasbt commented on August 20, 2024

I think this issue is resolved now. But please feel free to reopen & respond if you are still having issues. Thanks again for reporting!
