Comments (4)
Nice, it turns out you were right: the .unsqueeze(0) was indeed redundant. Love it, it makes the code even simpler and more readable!
from llms-from-scratch.
Also I have a question - could you please explain why we need to call contiguous() in the following line in the MultiHeadAttention class:
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
Ah yes, this was unnecessary, so I updated it to just mask_bool.unsqueeze(0) a while back. I will look into whether I can remove it altogether, like you suggest. Thanks!
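Removing it altogether should indeed work, because PyTorch broadcasts from the trailing dimensions: a 2-D boolean mask lines up against 4-D attention scores without any unsqueeze calls. A minimal sketch (the shapes here are made up for illustration, not the book's defaults):

```python
import torch

num_tokens = 4
# Causal mask: True above the diagonal (positions to be masked out)
mask_bool = torch.triu(
    torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1
)
# Attention scores with shape (batch, num_heads, num_tokens, num_tokens)
scores = torch.randn(2, 3, num_tokens, num_tokens)

# Broadcasting aligns trailing dimensions, so the 2-D mask applies to
# every batch and head; the explicit unsqueeze calls change nothing
a = scores.masked_fill(mask_bool.unsqueeze(0).unsqueeze(0), float("-inf"))
b = scores.masked_fill(mask_bool, float("-inf"))
print(torch.equal(a, b))  # True
```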
Also I have a question - could you please explain why we need to call contiguous() in the following line in the MultiHeadAttention class:
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
Good question. This is because of the way the memory is organized in this tensor; without the .contiguous() call, .view() would raise an error. What you could do instead is
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
This is because (quoting from the documentation):
When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
However, I haven't used .reshape elsewhere in this book, so I wanted to stick with .view for consistency.
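To make the difference concrete, here's a small sketch (tensor shapes chosen arbitrarily) showing that .view() fails on a transposed, non-contiguous tensor, while .contiguous().view() and .reshape() both succeed, with .reshape() copying behind the scenes when it has to:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)  # contiguous
y = x.transpose(1, 2)                  # shape (2, 4, 3); same storage, new strides
print(y.is_contiguous())               # False

try:
    y.view(2, 12)                      # strides are incompatible with a flat view
except RuntimeError as e:
    print("view failed:", e)

z1 = y.contiguous().view(2, 12)        # explicit copy first, then a cheap view
z2 = y.reshape(2, 12)                  # reshape copies for us in this case
print(torch.equal(z1, z2))             # True
print(z2.data_ptr() == y.data_ptr())   # False: reshape had to allocate a copy
```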
Sebastian, thanks a lot for your response.
Good question. This is because of the way the memory is organized in this tensor; without the .contiguous() call, .view() would raise an error.
Yes, I asked this question because when I deleted .contiguous():
context_vec = context_vec.view(b, num_tokens, self.d_out)
I didn't get any errors and got the same results.
The only other reason to convert to a contiguous tensor that I found here was the following:
This creates issues with parallel computations.
But I didn't find a more detailed explanation.
Could you please share your thoughts about it?
Thank you.
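I can't reproduce your exact setup, but one shape-dependent case where dropping .contiguous() happens not to fail: when the transposed dimension has size 1 (e.g., a single attention head), the transpose doesn't actually change the memory layout, so .view() still succeeds. A hypothetical illustration (the shapes are assumptions, not the book's defaults):

```python
import torch

b, num_heads, num_tokens, head_dim = 2, 1, 4, 8  # note: a single head
t = torch.randn(b, num_heads, num_tokens, head_dim)
u = t.transpose(1, 2)  # (b, num_tokens, 1, head_dim)

# Swapping in a size-1 dimension leaves the layout effectively unchanged,
# so PyTorch still reports the tensor as contiguous
print(u.is_contiguous())  # True

v = u.view(b, num_tokens, num_heads * head_dim)  # works without .contiguous()
print(v.shape)  # torch.Size([2, 4, 8])
```

With num_heads > 1 the same .view() raises a RuntimeError, so keeping .contiguous() is the safe general choice. As for the "parallel computations" remark, as far as I know non-contiguous tensors mainly cost extra strided memory traffic (and some kernels require contiguous inputs) rather than producing wrong results.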