Hi! It's me again.
I'm creating an annotated version of the UNet in lesson #7 (diffusion models). I'm adding more comments + assertions for the shapes of all inputs/outputs/weights/intermediate steps.
While doing this, I noticed what might be a mistake in one of the comments.
Here's the code that runs the UNet on dummy data (from the lesson):
```python
# A dummy batch of 10 3-channel 32px images
x = torch.randn(10, 3, 32, 32)
# 't' - what timestep are we on
t = torch.tensor([50], dtype=torch.long)
# Define the unet model
unet = UNet()
# The forward pass (takes both x and t)
model_output = unet(x, t)
```
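To make the mismatch concrete, here's a quick check I ran on just the shapes (this only reproduces the dummy inputs from the lesson, not the UNet itself; `t_per_image` is my own name for what I'd expect the input to look like):

```python
import torch

# The dummy inputs from the lesson's test code
x = torch.randn(10, 3, 32, 32)            # batch of 10 images
t = torch.tensor([50], dtype=torch.long)  # a single time-step

batch_size = x.shape[0]
print(t.shape)          # torch.Size([1])
assert t.shape[0] == 1  # i.e. NOT equal to batch_size (10)

# What I'd expect instead: one time-step per image, e.g.
t_per_image = torch.randint(0, 1000, (batch_size,), dtype=torch.long)
assert t_per_image.shape[0] == batch_size
```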
Inside the actual UNet, this is the forward pass:

```python
def forward(self, x: torch.Tensor, t: torch.Tensor):
    """
    * `x` has shape `[batch_size, in_channels, height, width]`
    * `t` has shape `[batch_size]`
    """
    # Get time-step embeddings
    t = self.time_emb(t)
```
The docstring says that the shape of `t` is `[batch_size]`. But the shape of `t` is actually `[1]`, which is to be expected if we look at the code that tests the UNet.
Specifically, this assertion:

```python
batch_size = x.shape[0]
print(t.shape)
assert t.shape[0] == batch_size
```

fails.
I'm not sure exactly what's going on here. My hypothesis is as follows: the UNet is being trained on a batch of images, and each image in the batch should be accompanied by its own time-step. However, it looks like only a single time-step is being passed into the UNet.
Somewhere along the line, this time-step is being accidentally broadcast by PyTorch to fit the batch dimension, so it ends up being used as the time-step for all images.
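Here's a minimal sketch of the silent broadcasting I suspect (the names `emb` and `feats` are mine, and the shapes are simplified stand-ins for the UNet's internals, not the actual lesson code):

```python
import torch

batch_size, d = 10, 64

# Time embedding computed from a single time-step: shape [1, d]
emb = torch.randn(1, d)
# Per-image features: shape [batch_size, d]
feats = torch.randn(batch_size, d)

# PyTorch silently broadcasts the size-1 batch dimension, so this "works"...
out = feats + emb
assert out.shape == (batch_size, d)

# ...but every image receives the exact same time embedding:
assert torch.allclose(out[0] - feats[0], out[5] - feats[5])
```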
Does that sound correct to you?