latent-diffusion's People

Contributors

ak391, crowsonkb, pesser, rromb

latent-diffusion's Issues

Text2Image Training

Maybe I'm an idiot, but if it's not already somewhere in the repo, are you intending to release the script you used to train the text2image model? And what specs were needed to train it in the first place? Thanks!

Model size

Thanks for your interesting work.
How does the model size of LDM compare with StyleGAN and ProjectedGAN?

Non-descriptive error when sampling in Colab

When I run !python scripts/txt2img.py --prompt "a virus monster is playing guitar, oil on canvas" --ddim_eta 0.0 --n_samples 4 --n_iter 4 --scale 5.0 --ddim_steps 50

It produces

Loading model from models/ldm/text2img-large/model.ckpt
1
2
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 872.30 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
^C

I added numbered print statements to see where the script fails, modifying it like so:

def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    print(1)
    sd = pl_sd["state_dict"]
    print(2)
    model = instantiate_from_config(config.model)
    print(3)
    m, u = model.load_state_dict(sd, strict=False)
    print(4)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model

What is that ^C and how do I debug this?
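
One possible reading, offered as an assumption rather than a confirmed diagnosis: `^C` is how a terminal echoes an interrupt, so the process may have been stopped from outside the script, for example by the runtime when host memory runs out while loading the model. A minimal sketch for narrowing this down replaces the bare `print(1)` calls with a helper that flushes immediately and reports elapsed time and peak memory per step. The names `trace_step` and `peak_rss_mb` are hypothetical, and the `resource` module is only available on Unix-like systems such as a Linux Colab runtime.

```python
import resource
import sys
import time
from contextlib import contextmanager


def peak_rss_mb():
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform.startswith("linux"):
        return peak / 1024
    return peak / (1024 * 1024)


@contextmanager
def trace_step(name, log=print):
    # Announce the step before it runs, so the last line printed names
    # the step that was in progress if the process is killed.
    log(f"[start] {name}", flush=True)
    t0 = time.perf_counter()
    try:
        yield
    finally:
        dt = time.perf_counter() - t0
        log(f"[done]  {name}: {dt:.2f}s, peak RSS {peak_rss_mb():.0f} MB", flush=True)


if __name__ == "__main__":
    # Stand-in workload; in the script above this would wrap torch.load(...)
    # and instantiate_from_config(...).
    with trace_step("allocate list"):
        data = [0] * 1_000_000
```

If a `[start]` line appears without a matching `[done]` line, the step named there is where the process stopped, and the peak memory printed by the earlier steps gives a rough idea of how close the run was to the machine's limit.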

Training steps for autoencoder training

Thanks for sharing your great project.

How many steps did you train your autoencoders on ImageNet? It seems that there is no description of the number of training steps or epochs in the provided configuration files. Sorry if it's mentioned elsewhere.

Reproducing inpainting results

Hi,
thanks for this great repo! I was trying to reproduce the inpainting results on the example images, but I obtained noticeable artifacts.
image
image

Do you have an idea what could be the reason? I am running:
python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inpainting_results

super resolution example

Thanks for the great model :).
I read the README but couldn't find a super-resolution example like scripts/inpaint.py.
It looks like super-resolution can be done with notebook_helpers.py.
Do you plan to add a super-resolution example?

Question about some techniques

Hi @pesser!
Thank you for sharing the implementation of your wonderful work!

I have questions about some of the techniques.
Could you answer them?

I used your pretrained celeba256 weights.
The images were recorded using calls such as intermediates['x_inter'].append(img).

  1. Why do you subsample the time steps? According to this line, it seems you pick one time value every num_ddpm_timesteps // num_ddim_timesteps steps.
    I have never seen this technique before.

  2. If I do not subsample the time steps and run T: 1000 -> 0, i.e. the time steps are contiguous and range from 1 to 1000, I cannot get clear results. This image was recorded in six separate 1000 iterations.
    This image was recorded using x_inter.
    image

This image was recorded using pred_x0.
image

  3. If the time steps are fixed to the default values in this line and the start time is decreased, e.g. to 800, I cannot get clear results.
    Why? Does your method not work well for anything other than t=1000? (Actually, with t=1000 there are 50 iterations, because the time steps are subsampled.)
    image

  4. If your method cannot do what question (3) describes, does that mean it cannot perform the denoising technique shown in Sohl-Dickstein+ ICML15 either? Would it be possible to accomplish it?
    image

Best regards,
Udon
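
As a rough illustration of the strided schedule asked about in question 1: DDIM sampling evaluates the model on a subsequence of the training time steps, which is why 50 sampling iterations can cover a 1000-step schedule. The sketch below builds a uniformly strided subsequence from the expression quoted in the question. The function name `uniform_ddim_timesteps` is hypothetical, and the exact offsets used in the repository may differ.

```python
def uniform_ddim_timesteps(num_ddpm_timesteps, num_ddim_timesteps):
    # Stride between consecutive sampled time steps.
    stride = num_ddpm_timesteps // num_ddim_timesteps
    # Every stride-th step of the training schedule, in ascending order.
    return list(range(0, num_ddpm_timesteps, stride))


steps = uniform_ddim_timesteps(1000, 50)
print(len(steps), steps[:3], steps[-1])  # 50 [0, 20, 40] 980

# Sampling walks the subsequence backwards, starting from the noisiest step.
for t in reversed(steps[-3:]):
    print("denoise at t =", t)
```

When the training step count is not a multiple of the requested sampling step count, the integer stride means the subsequence can contain a different number of steps than requested.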

Colab Notebook example fails on weight inputs for conv.py

Was attempting to run the Colab for this project to gauge functionality. Here is the error it threw when running the model:

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    441                             _pair(0), self.dilation, self.groups)
    442         return F.conv2d(input, weight, bias, self.stride,
--> 443                         self.padding, self.dilation, self.groups)
    444 
    445     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [128, 3, 3, 3], expected input[1, 4, 128, 128] to have 3 channels, but got 4 channels instead

I have seen this before when there is a mismatch in PyTorch libraries. Are we sure the dependencies are accurate?
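
For reading the error message itself: a `conv2d` weight has shape `[out_channels, in_channels // groups, kH, kW]`, so a weight of size `[128, 3, 3, 3]` with `groups=1` accepts 3-channel input, while the tensor that arrived has 4 channels. Where the 4-channel tensor comes from is not stated in the report. The pure-Python sketch below reproduces that check from the two shapes alone; `conv2d_channel_check` is a hypothetical helper, not part of PyTorch.

```python
def conv2d_channel_check(weight_shape, input_shape, groups=1):
    """Return (expected, got) input channels for a conv2d call.

    weight_shape: (out_channels, in_channels // groups, kH, kW)
    input_shape:  (batch, channels, height, width)
    """
    expected = weight_shape[1] * groups
    got = input_shape[1]
    return expected, got


# Shapes taken from the RuntimeError above.
expected, got = conv2d_channel_check((128, 3, 3, 3), (1, 4, 128, 128))
if expected != got:
    print(f"expected {expected} channels, got {got}")  # expected 3 channels, got 4
```

Printing the shape of the tensor just before the failing layer is usually the quickest way to tell whether the input or the model config is the side that is off.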

Question about training stability

Hello, thank you so much for this wonderful paper and codebase. I am trying to reproduce the results of lsun_churches-ldm-kl-8.yaml. I have not modified any parameters in the config and I am using your pretrained first stage model.

However, some part of training is not working correctly -- the losses are not decreasing as expected.

My loss curves are below:
Loss curves

Do you know what might be going wrong here? I feel like I have done something incorrectly, but I believe that I followed the instructions closely.

Thank you for your help!

Error in the configuration files for SIS task

Hi,

I noticed that the config.yaml files in models/ldm/semantic_synthesis256 and models/ldm/semantic_synthesis512 follow the configuration of LDM-4.
However, on pages 25 and 5 of your paper, you state that the LDM used for semantic image synthesis is LDM-8.
Although I can change the hyperparameters for the models thanks to the detailed description, I still need the hyperparameters related to training/optimizing the model.
Could you provide the correct configuration files?

terminate called after throwing an instance of 'c10::Error'

I am playing with ldm.models.diffusion.ddpm.LatentDiffusion with 4 GPUs and DDP distribution. After around 30 epochs, it stopped,

`terminate called after throwing an instance of 'c10::Error'
what(): CUDA error: initialization error
Exception raised from insert_events at /opt/conda/conda-bld/pytorch_1603729096996/work/c10/cuda/CUDACachingAllocator.cpp:717 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f082820c8b2 in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x1070 (0x7f082845ef20 in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::TensorImpl::release_resources() + 0x4d (0x7f08281f7b7d in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #3: + 0x5f65b2 (0x7f08725575b2 in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #4: + 0x13c2bc (0x55b1c22232bc in /root/miniconda3/envs/ldm/bin/python)
frame #5: + 0x1efd35 (0x55b1c22d6d35 in /root/miniconda3/envs/ldm/bin/python)
frame #6: _PyObject_GC_Malloc + 0x88 (0x55b1c2223998 in /root/miniconda3/envs/ldm/bin/python)
frame #7: PyType_GenericAlloc + 0x3b (0x55b1c2293a8b in /root/miniconda3/envs/ldm/bin/python)
frame #8: + 0xc385 (0x7f08a1bbf385 in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/numpy/random/bit_generator.cpython-38-x86_64-linux-gnu.so)
frame #9: + 0x13d585 (0x55b1c2224585 in /root/miniconda3/envs/ldm/bin/python)
frame #10: + 0xf97f (0x7f08a1bc297f in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/numpy/random/bit_generator.cpython-38-x86_64-linux-gnu.so)
frame #11: + 0xfb7e (0x7f08a1bc2b7e in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/numpy/random/bit_generator.cpython-38-x86_64-linux-gnu.so)
frame #12: + 0x1e857 (0x7f08a1bd1857 in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/numpy/random/bit_generator.cpython-38-x86_64-linux-gnu.so)
frame #13: + 0x5f92c (0x55b1c214692c in /root/miniconda3/envs/ldm/bin/python)
frame #14: + 0x16fb40 (0x55b1c2256b40 in /root/miniconda3/envs/ldm/bin/python)
frame #15: + 0xe4d6 (0x7f08a17a84d6 in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/numpy/random/mt19937.cpython-38-x86_64-linux-gnu.so)
frame #16: + 0x13d60c (0x55b1c222460c in /root/miniconda3/envs/ldm/bin/python)
frame #17: + 0x14231 (0x7f08a1bf4231 in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/numpy/random/mtrand.cpython-38-x86_64-linux-gnu.so)
frame #18: + 0x21d0e (0x7f08a1c01d0e in /root/miniconda3/envs/ldm/lib/python3.8/site-packages/numpy/random/mtrand.cpython-38-x86_64-linux-gnu.so)
frame #19: _PyObject_MakeTpCall + 0x1a4 (0x55b1c22247d4 in /root/miniconda3/envs/ldm/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x4596 (0x55b1c22abf56 in /root/miniconda3/envs/ldm/bin/python)
frame #21: _PyEval_EvalCodeWithName + 0x2d2 (0x55b1c2271a92 in /root/miniconda3/envs/ldm/bin/python)
frame #22: _PyFunction_Vectorcall + 0x1e3 (0x55b1c2272943 in /root/miniconda3/envs/ldm/bin/python)
frame #23: + 0x18be79 (0x55b1c2272e79 in /root/miniconda3/envs/ldm/bin/python)
frame #24: PyVectorcall_Call + 0x71 (0x55b1c2224041 in /root/miniconda3/envs/ldm/bin/python)
frame #25: _PyEval_EvalFrameDefault + 0x1fdb (0x55b1c22a999b in /root/miniconda3/envs/ldm/bin/python)
frame #26: _PyEval_EvalCodeWithName + 0x7df (0x55b1c2271f9f in /root/miniconda3/envs/ldm/bin/python)
frame #27: _PyFunction_Vectorcall + 0x1e3 (0x55b1c2272943 in /root/miniconda3/envs/ldm/bin/python)
frame #28: + 0x18be79 (0x55b1c2272e79 in /root/miniconda3/envs/ldm/bin/python)
frame #29: PyVectorcall_Call + 0x71 (0x55b1c2224041 in /root/miniconda3/envs/ldm/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x1fdb (0x55b1c22a999b in /root/miniconda3/envs/ldm/bin/python)
frame #31: _PyEval_EvalCodeWithName + 0x7df (0x55b1c2271f9f in /root/miniconda3/envs/ldm/bin/python)
frame #32: _PyFunction_Vectorcall + 0x1e3 (0x55b1c2272943 in /root/miniconda3/envs/ldm/bin/python)
frame #33: _PyObject_FastCallDict + 0x24b (0x55b1c22734cb in /root/miniconda3/envs/ldm/bin/python)
frame #34: _PyObject_Call_Prepend + 0x63 (0x55b1c2273733 in /root/miniconda3/envs/ldm/bin/python)
frame #35: + 0x18c83a (0x55b1c227383a in /root/miniconda3/envs/ldm/bin/python)
frame #36: PyObject_Call + 0x70 (0x55b1c2224200 in /root/miniconda3/envs/ldm/bin/python)
frame #37: _PyEval_EvalFrameDefault + 0x1fdb (0x55b1c22a999b in /root/miniconda3/envs/ldm/bin/python)
frame #38: _PyEval_EvalCodeWithName + 0x2d2 (0x55b1c2271a92 in /root/miniconda3/envs/ldm/bin/python)
frame #39: _PyFunction_Vectorcall + 0x1e3 (0x55b1c2272943 in /root/miniconda3/envs/ldm/bin/python)
frame #40: _PyObject_FastCallDict + 0x24b (0x55b1c22734cb in /root/miniconda3/envs/ldm/bin/python)
frame #41: _PyObject_Call_Prepend + 0x63 (0x55b1c2273733 in /root/miniconda3/envs/ldm/bin/python)
frame #42: + 0x18c83a (0x55b1c227383a in /root/miniconda3/envs/ldm/bin/python)
frame #43: _PyObject_MakeTpCall + 0x22f (0x55b1c222485f in /root/miniconda3/envs/ldm/bin/python)
frame #44: _PyEval_EvalFrameDefault + 0x11d0 (0x55b1c22a8b90 in /root/miniconda3/envs/ldm/bin/python)
frame #45: _PyFunction_Vectorcall + 0x10b (0x55b1c227286b in /root/miniconda3/envs/ldm/bin/python)
frame #46: + 0xba0de (0x55b1c21a10de in /root/miniconda3/envs/ldm/bin/python)
frame #47: + 0x17eb32 (0x55b1c2265b32 in /root/miniconda3/envs/ldm/bin/python)
frame #48: PyObject_GetItem + 0x49 (0x55b1c22568c9 in /root/miniconda3/envs/ldm/bin/python)
frame #49: _PyEval_EvalFrameDefault + 0xbdd (0x55b1c22a859d in /root/miniconda3/envs/ldm/bin/python)
frame #50: _PyEval_EvalCodeWithName + 0x659 (0x55b1c2271e19 in /root/miniconda3/envs/ldm/bin/python)
frame #51: _PyFunction_Vectorcall + 0x1e3 (0x55b1c2272943 in /root/miniconda3/envs/ldm/bin/python)
frame #52: + 0xfeb84 (0x55b1c21e5b84 in /root/miniconda3/envs/ldm/bin/python)
frame #53: _PyEval_EvalCodeWithName + 0x7df (0x55b1c2271f9f in /root/miniconda3/envs/ldm/bin/python)
frame #54: _PyFunction_Vectorcall + 0x1e3 (0x55b1c2272943 in /root/miniconda3/envs/ldm/bin/python)
frame #55: + 0x10075e (0x55b1c21e775e in /root/miniconda3/envs/ldm/bin/python)
frame #56: _PyFunction_Vectorcall + 0x10b (0x55b1c227286b in /root/miniconda3/envs/ldm/bin/python)
frame #57: PyVectorcall_Call + 0x71 (0x55b1c2224041 in /root/miniconda3/envs/ldm/bin/python)
frame #58: _PyEval_EvalFrameDefault + 0x1fdb (0x55b1c22a999b in /root/miniconda3/envs/ldm/bin/python)
frame #59: _PyFunction_Vectorcall + 0x10b (0x55b1c227286b in /root/miniconda3/envs/ldm/bin/python)
frame #60: + 0x10075e (0x55b1c21e775e in /root/miniconda3/envs/ldm/bin/python)
frame #61: _PyEval_EvalCodeWithName + 0x2d2 (0x55b1c2271a92 in /root/miniconda3/envs/ldm/bin/python)
frame #62: + 0x18bd20 (0x55b1c2272d20 in /root/miniconda3/envs/ldm/bin/python)
frame #63: + 0x10011a (0x55b1c21e711a in /root/miniconda3/envs/ldm/bin/python)

Epoch 37: 69%|▋| 227/328 [18:34<08:13, 4.89s/it, loss=0.794, v_num=2, train/loss_simple_step=0.792, train/loss_vlb_step=0.0081, tra
Traceback (most recent call last):
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
self.fit_loop.run()
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
epoch_output = self.epoch_loop.run(train_dataloader)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 130, in advance
batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 101, in run
super().run(batch, batch_idx, dataloader_idx)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 148, in advance
result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 202, in _run_optimization
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 396, in _optimizer_step
model_ref.optimizer_step(
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1618, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 209, in step
self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 129, in __optimizer_step
trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 296, in optimizer_step
self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 303, in run_optimizer_step
self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 226, in optimizer_step
optimizer.step(closure=lambda_closure, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
return func(*args, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/optim/adamw.py", line 65, in step
loss = closure()
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 236, in _training_step_and_backward_closure
result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 537, in training_step_and_backward
result = self._training_step(split_batch, batch_idx, opt_idx, hiddens)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 307, in _training_step
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 193, in training_step
return self.training_type_plugin.training_step(*step_kwargs.values())
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 383, in training_step
return self.model(*args, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 619, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 82, in forward
output = self.module.training_step(*inputs, **kwargs)
File "/root/Desktop/ldm/ldm/models/diffusion/ddpm.py", line 343, in training_step
loss, loss_dict = self.shared_step(batch)
File "/root/Desktop/ldm/ldm/models/diffusion/ddpm.py", line 887, in shared_step
x, c = self.get_input(batch, self.first_stage_key)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
return func(*args, **kwargs)
File "/root/Desktop/ldm/ldm/models/diffusion/ddpm.py", line 661, in get_input
z = self.get_first_stage_encoding(encoder_posterior).detach()
File "/root/Desktop/ldm/ldm/models/diffusion/ddpm.py", line 544, in get_first_stage_encoding
z = encoder_posterior.sample()
File "/root/Desktop/ldm/ldm/modules/distributions/distributions.py", line 36, in sample
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
File "/root/miniconda3/envs/ldm/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 21388) is killed by signal: Aborted.
`

I am sure it is related to this issue, but I was unable to fix it by setting rank_zero_only=True.

Any help is appreciated

Preventing/constraining words in the output

I asked LD for "fire" (samples=3, iter=2) and 5 of the 6 outputs had rendered some variant of the word "FIRE". Is it possible to somehow control whether text is rendered or not? Sometimes it's ideal (generating posters, books, etc.) but sometimes it ruins the render (cf "fire" above).

Or is it just a case that text from the training set is being picked up and there's nothing that can be done (other than training on non-textual sources)?

cannot load vq-f4 model

All the vq models work for me except the first one at https://ommer-lab.com/files/latent-diffusion/vq-f4.zip

using this config:

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.VQModel
  params:
    embed_dim: 3
    n_embed: 8192
    monitor: val/rec_loss
    ddconfig:
      double_z: false
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 0
        disc_weight: 0.75
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 16
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        crop_size: 256

code:

config = OmegaConf.load('./vq-f4/config.yaml')
pl_sd = torch.load('./vq-f4/model.ckpt', map_location="cpu")
sd = pl_sd["state_dict"]
ldm = instantiate_from_config(config.model)
ldm.load_state_dict(sd, strict=False)
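
To see at a glance how a checkpoint differs from the model that a config builds, the tensor shapes on both sides can be compared before calling `load_state_dict`. The sketch below is pure Python and works on plain name-to-shape dictionaries; `shape_mismatches` is a hypothetical helper, and the sample data is illustrative (the first entry mirrors the `encoder.conv_out.bias` mismatch in this issue's error output, the second is made up).

```python
def shape_mismatches(ckpt_shapes, model_shapes):
    """List (name, checkpoint_shape, model_shape) for tensors whose shapes differ."""
    shared = sorted(set(ckpt_shapes) & set(model_shapes))
    return [
        (name, tuple(ckpt_shapes[name]), tuple(model_shapes[name]))
        for name in shared
        if tuple(ckpt_shapes[name]) != tuple(model_shapes[name])
    ]


# With PyTorch, the inputs could be built like this (assumed usage):
#   ckpt_shapes = {k: tuple(v.shape) for k, v in sd.items()}
#   model_shapes = {k: tuple(v.shape) for k, v in ldm.state_dict().items()}
ckpt_shapes = {"encoder.conv_out.bias": (8,), "encoder.conv_in.bias": (128,)}
model_shapes = {"encoder.conv_out.bias": (3,), "encoder.conv_in.bias": (128,)}

for name, ckpt, model in shape_mismatches(ckpt_shapes, model_shapes):
    print(f"{name}: checkpoint {ckpt} vs model {model}")
```

A systematic pattern in the output, such as every mismatched layer being half as wide in the checkpoint, suggests that the config and the checkpoint describe different architectures rather than that the file is corrupted.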

error:

RuntimeError: Error(s) in loading state_dict for VQModel:
	size mismatch for encoder.down.1.block.0.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
	size mismatch for encoder.down.1.block.0.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.0.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.0.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.0.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for encoder.down.1.block.0.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.1.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.1.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for encoder.down.1.block.1.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.1.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.1.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.1.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for encoder.down.1.block.1.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.downsample.conv.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for encoder.down.1.downsample.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.2.block.0.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.2.block.0.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.2.block.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
	size mismatch for encoder.down.2.block.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for encoder.down.2.block.0.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
	size mismatch for encoder.down.2.block.0.nin_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for encoder.down.2.block.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for encoder.down.2.block.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.conv_out.weight: copying a param with shape torch.Size([8, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 512, 3, 3]).
	size mismatch for encoder.conv_out.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([3]).
	size mismatch for decoder.conv_in.weight: copying a param with shape torch.Size([512, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 3, 3, 3]).
	size mismatch for decoder.up.0.block.0.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.0.block.0.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.0.block.0.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).
	size mismatch for decoder.up.1.block.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.1.block.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.1.block.0.conv1.weight: copying a param with shape torch.Size([128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
	size mismatch for decoder.up.1.block.0.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.0.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.0.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.0.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.block.0.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.0.nin_shortcut.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for decoder.up.1.block.0.nin_shortcut.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.block.1.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.block.1.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.2.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.2.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.2.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.block.2.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.2.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.2.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.2.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.block.2.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.upsample.conv.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.upsample.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.2.block.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.0.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.0.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.2.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.2.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.2.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.2.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.2.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.2.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.2.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.2.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.upsample.conv.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.upsample.conv.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for loss.discriminator.main.8.weight: copying a param with shape torch.Size([1, 256, 4, 4]) from checkpoint, the shape in current model is torch.Size([512, 256, 4, 4]).
	size mismatch for quantize.embedding.weight: copying a param with shape torch.Size([16384, 8]) from checkpoint, the shape in current model is torch.Size([8192, 3]).
	size mismatch for quant_conv.weight: copying a param with shape torch.Size([8, 8, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 3, 1, 1]).
	size mismatch for quant_conv.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([3]).
	size mismatch for post_quant_conv.weight: copying a param with shape torch.Size([8, 8, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 3, 1, 1]).
	size mismatch for post_quant_conv.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([3]).

add web demo/models to Huggingface

Hi, would you be interested in adding latent-diffusion to Hugging Face? The Hub offers free hosting, and it would make your work more accessible and visible to the rest of the ML community. Models, datasets, and Spaces (web demos) can be added to a user account or organization, similar to GitHub.

Example from other organizations:
Keras: https://huggingface.co/keras-io
Microsoft: https://huggingface.co/microsoft
Facebook: https://huggingface.co/facebook

Example spaces with repos:
github: https://github.com/salesforce/BLIP
Spaces: https://huggingface.co/spaces/salesforce/BLIP

github: https://github.com/facebookresearch/omnivore
Spaces: https://huggingface.co/spaces/akhaliq/omnivore

and here are guides for adding spaces/models/datasets to your org

How to add a Space: https://huggingface.co/blog/gradio-spaces
how to add models: https://huggingface.co/docs/hub/adding-a-model
uploading a dataset: https://huggingface.co/docs/datasets/upload_dataset.html

Please let us know if you would be interested and if you have any questions, we can also help with the technical implementation.

Models on COCO dataset

Hi, thank you for the wonderful paper and the available code. The unconditional models you have released are all trained on very specific datasets, such as faces and buildings. I know it was not mentioned in the paper, but did you try training an unconditional model on the COCO dataset, which contains a wider variety of objects? If so, how did it perform, and could you share the pre-trained model?

Thank you!

question about disc_start

Thanks for this great work.

I am trying to train VQ models on custom data and noticed that disc_start differs widely between the VQ configs,
for example,
vq-f8-n256 disc_start: 250001
vq-f8 disc_start: 1

Is there a particular reason why the discriminator can start at step 1?
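
For context, here is a minimal sketch of how a disc_start threshold is commonly applied in VQGAN-style losses: the adversarial term is multiplied by zero until the global step reaches the threshold. The function name `gated_disc_weight` and the exact gating rule are assumptions for illustration; the source does not confirm how this repository implements it.

```python
def gated_disc_weight(weight, global_step, disc_start, off_value=0.0):
    """Return `weight` once training has reached `disc_start`, else `off_value`.

    Hypothetical helper: it only illustrates step-based gating of the
    adversarial (discriminator) loss term.
    """
    if global_step < disc_start:
        return off_value
    return weight

# With disc_start=250001 the adversarial term is still off at step 1000.
print(gated_disc_weight(1.0, global_step=1000, disc_start=250001))
# With disc_start=1 it is already active at step 1000.
print(gated_disc_weight(1.0, global_step=1000, disc_start=1))
```

Under that assumption, disc_start: 1 means the discriminator contributes almost from the first step, while disc_start: 250001 gives the autoencoder a long reconstruction-only warm-up.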

All checkpoints and links on ommer-lab 404 not found

Hey thanks for all your work and the excellent readme!

There seems to be an issue with all files having moved or disappeared from https://ommer-lab.com/files/latent-diffusion/* , all the links are 404-ing now. The heibox links still work fine.

If the hosting is going to be an issue, it would be nice if the checkpoints were all uploaded as a GitHub release on this repo https://github.com/CompVis/latent-diffusion/releases , so that GitHub covers the hosting indefinitely and it is no longer a worry for future maintenance. See SwinIR's release as an example: https://github.com/JingyunLiang/SwinIR/releases

I'm super interested in text2img so hopefully these can be restored 😃 .
Thanks again!

Allow inference on CPU...

Tried to allocate 128.00 MiB (GPU 0; 7.79 GiB total capacity; 6.19 GiB already allocated; 81.44 MiB free; 6.34 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.

Setting CUDA_VISIBLE_DEVICES to -1 to force CPU results in no cuda devices found.

Please allow CPU inference at the expense of time.

the stability of training(a collapse loss)

Thanks for the great work.
I tried to train the LDM model on ImageNet with 8 V100 GPUs, but got a bad result. The loss looked normal at first but soon collapsed:

image
image

and the sampled images are all noise at 5000 steps:
image

How can I solve this problem? Thank you very much.

License

Thank you for the awesome work!
What is the license the models and code are released under?

Evaluating first stage autoencoders

First of all, thank you so much for making this high-quality repository as well as pretrained models publicly available! This is highly useful for exploring your research.

I am currently training first stage autoencoders on a custom dataset (SoundCloud images) and am struggling with evaluating these models (other than with the loss values logged in TensorBoard). I plan to compare the performance of initializing the autoencoder weights randomly vs. fine-tuning one of your pretrained autoencoders.

I would prefer to calculate rFID, PSNR, and PSIM the same way as you did for your results table. Could you please provide a hint as to how you evaluate your autoencoders? Is there some other repository or toolkit that you rely on?
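
Of the three metrics, PSNR is the one that can be computed without any pretrained network. The following is a minimal sketch of the standard definition, PSNR = 10 * log10(MAX^2 / MSE), on flat lists of pixel values. It is not necessarily how the authors computed their table, and the function name `psnr` and the [0, 1] value range are assumptions. rFID and PSIM rely on pretrained feature extractors and are not covered here.

```python
import math

def psnr(reference, reconstruction, max_value=1.0):
    """PSNR in dB between two equally sized flat sequences of pixel values."""
    if len(reference) != len(reconstruction) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(max_value ** 2 / mse)

# Stand-ins for an input image and its autoencoder reconstruction.
original = [0.0, 0.5, 1.0, 0.25]
decoded = [0.1, 0.4, 0.9, 0.35]
print(f"PSNR: {psnr(original, decoded):.2f} dB")
```

When comparing a randomly initialized autoencoder against a fine-tuned one, the same held-out images and the same value range should be used for both, since PSNR depends directly on `max_value`.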

Conda env yaml should be changed

Thanks for the great research. I ran into a conda environment error when using the commands below.

conda env create -f environment.yaml
conda activate ldm
# environment.yaml
name: ldm
...
dependencies:
  ...
  - pytorch=1.7.0
  - torchvision=0.8.1
  - pip:
    ...
    - pytorch-lightning==1.4.2
    ...

pytorch-lightning==1.4.2 automatically runs "from torchmetrics.utilities.data import get_num_classes as _get_num_classes", but that function was dropped by this PR.

So the yaml file should be changed, either by updating pytorch, torchvision, and pytorch-lightning or by pinning an explicit torchmetrics version.


Error Log

(ldm) ubuntu@nipa2021-19981:~/jwk/latent-diffusion$ python scripts/txt2img.py --prompt "a sunset behind a mountain range, vector image" --ddim_eta 1.0 --n_samples 1 --n_iter 1 --H 384 --W 1024 --scale 5.0  
Loading model from models/ldm/text2img-large/model.ckpt
Traceback (most recent call last):
  File "scripts/txt2img.py", line 101, in <module>
    model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt")  # TODO: check path
  File "scripts/txt2img.py", line 18, in load_model_from_config
    model = instantiate_from_config(config.model)
  File "/home/ubuntu/jwk/latent-diffusion/ldm/util.py", line 78, in instantiate_from_config
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
  File "/home/ubuntu/jwk/latent-diffusion/ldm/util.py", line 86, in get_obj_from_str
    return getattr(importlib.import_module(module, package=None), cls)
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 783, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/ubuntu/jwk/latent-diffusion/ldm/models/diffusion/ddpm.py", line 12, in <module>
    import pytorch_lightning as pl
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/__init__.py", line 20, in <module>
    from pytorch_lightning import metrics  # noqa: E402
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/metrics/__init__.py", line 15, in <module>
    from pytorch_lightning.metrics.classification import (  # noqa: F401
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/__init__.py", line 14, in <module>
    from pytorch_lightning.metrics.classification.accuracy import Accuracy  # noqa: F401
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/accuracy.py", line 18, in <module>
    from pytorch_lightning.metrics.utils import deprecated_metrics, void
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/metrics/utils.py", line 22, in <module>
    from torchmetrics.utilities.data import get_num_classes as _get_num_classes
ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/torchmetrics/utilities/data.py)
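
Before changing the environment file, it can help to confirm that the installed torchmetrics really lacks the function named in the traceback. The following is a minimal sketch using only the standard library; `has_attribute` is a hypothetical helper, not part of any of the packages involved.

```python
import importlib

def has_attribute(module_name, attribute):
    """Return True if `module_name` imports and exposes `attribute`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attribute)

# The import that fails in the traceback above.
if has_attribute("torchmetrics.utilities.data", "get_num_classes"):
    print("torchmetrics still provides get_num_classes")
else:
    print("get_num_classes is missing (or torchmetrics is not installed)")
```

If the check reports the function as missing, the installed torchmetrics is newer than what pytorch-lightning 1.4.2 expects, which matches the diagnosis above.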

autoencoder for LDM

Hi!
Could you put which autoencoding models correspond to which LDMs on the table, please?
Maybe I am missing this information somewhere, but it seems it's not clear which one is for which.

Same seed produces different outputs

Prompt: winter, sunrise, path in the forest, painted by Caspar David Friedrich (royaltyfree)
Steps: 50
ETA: 0
Iterations: 1
Width: 384
Height: 256
Samples_in_parallel: 1
Diversity_scale: 5
PLMS_sampling: 0
Seed: 1

Download (27)

re-run with same parameters:

Download (26)

The problem with the checkpoint finetuning.

Hi. First of all, thank you for this wonderful repository. I try to run a training and have the following problem:

I downloaded a small part of the imagenet dataset (2Gb) and unzipped it. There were only images, so I had to change the "./ldm/data/imagenet.py" a bit to be able to load my dataset. The output gave example["image"] and example["LR_image"] as required.

Then I fixed a few lines in "./models/ldm/bsr_sr/config.yaml", namely in train and validation I changed target to the path to imagenet.py.

Then I downloaded your ckpt file from notebook_helpers.py and decided to try to finetune the weight.

CUDA_VISIBLE_DEVICES=0 python main.py \
--base "./models/ldm/bsr_sr/config.yaml" \
--name "test" \
--resume_from_checkpoint "./logs/diffusion/superresolution_bsr/last.yaml/?dl=1" \
-t --gpus=0

But I got an error:

RuntimeError: Error(s) in loading state_dict for LatentDiffusion:
Unexpected key(s) in state_dict: "ddim_sigmas", "ddim_alphas", "ddim_alphas_prev", "ddim_sqrt_one_minus_alphas".

If I read the weights, delete those 4 keys, and write them to a new file, the training starts fine. Do I understand correctly that the training will not work well without them? If I start the training from scratch, the resulting checkpoints do not contain these 4 keys at all. Can you tell me what I'm doing wrong?
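
The workaround described above amounts to filtering the checkpoint's state dict. The following is a minimal sketch on a plain dictionary; `drop_keys` is a hypothetical helper. With a real checkpoint one would presumably load the file with torch.load, filter the mapping that holds the weights, and write it back with torch.save; where exactly the weights sit inside the file is an assumption to verify. The key names suggest DDIM sampling schedule values, but the source does not confirm whether training needs them.

```python
def drop_keys(state_dict, unwanted):
    """Return a copy of `state_dict` without the keys listed in `unwanted`."""
    unwanted = set(unwanted)
    return {k: v for k, v in state_dict.items() if k not in unwanted}

# The four keys reported as unexpected in the error above.
DDIM_KEYS = [
    "ddim_sigmas",
    "ddim_alphas",
    "ddim_alphas_prev",
    "ddim_sqrt_one_minus_alphas",
]

# Stand-in for a loaded state dict; a real one would map names to tensors.
fake_state = {"model.weight": 1, "ddim_sigmas": 2, "ddim_alphas": 3}
print(sorted(drop_keys(fake_state, DDIM_KEYS)))
```

PyTorch's load_state_dict also accepts strict=False, which ignores unexpected keys instead of raising; whether the training script exposes that option is not something the source states.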


And another small question: I separately trained the autoencoder (first_stage_models), got the checkpoint, but I can't find where to specify it when training the diffusion model (ldm). Perhaps the autoencoder is not involved in this step, but then where do I specify it if I want to run an inference with my weights?

config file for conditional LDM

Thank you for sharing this great work.

Where can I find config files for the conditional tasks, such as Text-conditional Image Synthesis, Super-resolution, Layout-to-Image Synthesis, and Semantic Image Synthesis?

The download links only provide ckpt files, and the config files in configs/latent-diffusion are all for unconditional tasks.

pip error

When creating the ldm environment, pip encounters the following error:

ERROR: Requested clip from git+https://github.com/openai/CLIP.git@main#egg=cli (from -r requirements.txt (line 14)) has different name in metadata: 'clip'

Constraining the output to within the borders?

(Might be able to be solved as part of #34 where e.g. transparent areas are forbidden?)

I'm generating movie posters / book covers / etc. and most of the time, the output is off the edge of the image (see attachment.)

Would be super if there was a way to hint / constrain the output - it shouldn't have seen anything cut-off like that in the training sets, I think? VQGAN-CLIP doesn't have this issue (but also isn't generating as good output in as quick a time which is why I'd prefer to use LD.)

000544_BROGUE_NATION_in_the_style_of_a_1950s_book_cover_cl1oz9co00003ucobpewjzwmd_s9 0_3x2

Text + partial image prompting

Hi!

In Dall-E, we can provide a partial image in addition to the text description so that the model only completes the image. See:

Capture

Can we do the same with your models? That would be awesome.
I tried to modify the LAION-400M model notebook but without much success.

how can I train with semantic

Excuse me, I want to train the LDM for semantic image synthesis, but I cannot find the appropriate dataloaders (they may be 'landscapes.RFWTrain' and 'RFWValidation'). Will the dataloaders be released later, or can I find them elsewhere?

memory problem

Hello.
How much memory do you need to run?

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 6.00 GiB total capacity; 4.78 GiB already allocated; 0 bytes free; 4.82 GiB reserved in total by PyTorch)

Any solution to run with a small amount of memory?

[Question] Is it possible to gradually diffuse/transform one given real image to another using diffusion model?

Thanks for this great work. I'm quite interested in the possible applications of the (latent) diffusion model proposed in the paper, which already shows many promising uses of this newly emerging generative modeling approach. However, there is one question that has been bothering me for several days. It is an open problem, detailed below.

Question
Given an initial image (e.g. a 256x256 image of a red dog), can we use a diffusion model to gradually diffuse/transform it until it satisfies a condition (e.g. the text prompt "a yellow cat"), so that the final image depicts "a yellow cat"?

Difficulty
As we know, the diffusion model assumes that the starting point of sampling is drawn from a Gaussian distribution. That is not the case here: our initial image is a real image, which I think breaks the assumption.

I tried to implement the idea in the most naive way, but it does not seem to work, because it generates vague results.
It would be great if you could give some advice or suggestions on this problem. Thank you!!
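
One way around the Gaussian-start assumption, known from the SDEdit line of work rather than from this repository, is to noise the real image only part of the way with the standard DDPM forward process and then run the text-conditioned reverse process from that intermediate step. The following is a minimal sketch of the forward (noising) half only, using the closed form x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps. The linear beta schedule, its endpoints, and the function names are assumptions for illustration.

```python
import math
import random

def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative products of (1 - beta_t) for a linear beta schedule."""
    out, running = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        out.append(running)
    return out

def q_sample(x0, t, alphas_cumprod, noise=None):
    """Noise x0 to step t: sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise."""
    if noise is None:
        noise = [random.gauss(0.0, 1.0) for _ in x0]
    a_bar = alphas_cumprod[t]
    return [
        math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * n
        for x, n in zip(x0, noise)
    ]

schedule = linear_alphas_cumprod()
start_image = [0.8, -0.2, 0.1]  # stand-in for a real image (or its latent)
partly_noised = q_sample(start_image, t=400, alphas_cumprod=schedule)
print(partly_noised)
```

The choice of t trades faithfulness for freedom: a small t keeps most of the original image, while a large t approaches pure noise and lets the prompt dominate. The denoising half would use the trained model and is not shown.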

The stability of training

We use our own data, which only contains faces.
However, when we train the LDM, the loss does not decrease; it stays at about 0.798.

Epoch 5: 8%|▊ | 107/1422 [00:57<11:36, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.800, train/loss_vlb_step=0.0366, train/loss_step=0.800, global_step=6675.0, lr_abs=0.0032, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 107/1422 [00:57<11:36, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.799, train/loss_vlb_step=0.00406, train/loss_step=0.799, global_step=6676.0, lr_abs=0.0032, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 108/1422 [00:57<11:35, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.799, train/loss_vlb_step=0.00406, train/loss_step=0.799, global_step=6676.0, lr_abs=0.0032, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 108/1422 [00:57<11:35, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.799, train/loss_vlb_step=0.0137, train/loss_step=0.799, global_step=6677.0, lr_abs=0.0032, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 109/1422 [00:58<11:35, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.799, train/loss_vlb_step=0.0137, train/loss_step=0.799, global_step=6677.0, lr_abs=0.0032, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 109/1422 [00:58<11:35, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.795, train/loss_vlb_step=0.00617, train/loss_step=0.795, global_step=6678.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 110/1422 [00:58<11:35, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.795, train/loss_vlb_step=0.00617, train/loss_step=0.795, global_step=6678.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 110/1422 [00:58<11:35, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.798, train/loss_vlb_step=0.00474, train/loss_step=0.798, global_step=6679.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 111/1422 [00:59<11:34, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.798, train/loss_vlb_step=0.00474, train/loss_step=0.798, global_step=6679.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 111/1422 [00:59<11:34, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.796, train/loss_vlb_step=0.00421, train/loss_step=0.796, global_step=6680.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 112/1422 [00:59<11:34, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.796, train/loss_vlb_step=0.00421, train/loss_step=0.796, global_step=6680.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 112/1422 [00:59<11:34, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.797, train/loss_vlb_step=0.00464, train/loss_step=0.797, global_step=6681.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 113/1422 [01:00<11:34, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.797, train/loss_vlb_step=0.00464, train/loss_step=0.797, global_step=6681.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]
Epoch 5: 8%|▊ | 113/1422 [01:00<11:34, 1.89it/s, loss=0.798, v_num=0, train/loss_simple_step=0.794, train/loss_vlb_step=0.00373, train/loss_step=0.794, global_step=6682.0, lr_abs=0.00321, train/loss_simple_epoch=0.748, train/loss_vlb_epoch=0.0059, train/loss_epoch=0.748]

Thanks for any comments.

Details about training super resolution model

Hi @rromb, @ablattmann, @pesser, and thank you for making your great work publicly available.

Could you please supply the code for the class ldm.data.openimages.SuperresOpenImagesAdvancedTrain/Validation to train your model for super-resolution, as required in bsr_sr/config.yaml (see this line)?
Otherwise, some more information about how to train the SR model with datasets not included in your repository would be very helpful.

Thank you very much!

Render between two images

Is it possible to generate a sequence of images between two prompts to realize keyframe animation?
The basic idea is to render a set of frames using keyframes / multiple prompts.

If not, is it possible to dump the intermediate step images?
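
A common approach to this kind of keyframe animation, which the source does not confirm for this repository, is to interpolate between the conditioning vectors of the two prompts and sample one image per interpolated vector with a fixed seed. The following is a minimal sketch of the interpolation step only, on plain lists; `lerp` and `keyframe_path` are hypothetical helpers, and the short vectors stand in for real prompt embeddings.

```python
def lerp(a, b, weight):
    """Element-wise linear interpolation between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return [(1.0 - weight) * x + weight * y for x, y in zip(a, b)]

def keyframe_path(start, end, num_frames):
    """Return `num_frames` vectors going from `start` to `end` inclusive."""
    if num_frames < 2:
        raise ValueError("need at least two frames")
    return [lerp(start, end, i / (num_frames - 1)) for i in range(num_frames)]

# Stand-ins for the conditioning vectors of two prompts.
cond_a = [0.0, 1.0, 2.0]
cond_b = [1.0, 1.0, 0.0]
for frame in keyframe_path(cond_a, cond_b, num_frames=5):
    print(frame)
```

With more than two keyframes, the same function can be applied to each consecutive pair of prompts and the resulting segments concatenated.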

Thank you.

Cool!

Thank you very much for releasing the new checkpoints!

Would you mind sharing more details about the training of the text2img-large model? - Did you train it on the full, unfiltered LAION-400M? For how many epochs?
What hardware did you use for how long? :)

Kind regards,
Christoph Schuhmann
Organization Lead LAION

Btw, here is our new dataset, LAION-5B :-)
https://laion.ai/laion-5b-a-new-era-of-open-large-scale-multi-modal-datasets/

conda env settings issue

Thanks for the great research. I ran into a conda environment error when using the commands below.

conda env create -f environment.yaml
conda activate ldm

Error Log

(ldm) ubuntu@nipa2021-19981:~/jwk/latent-diffusion$ python scripts/txt2img.py --prompt "a sunset behind a mountain range, vector image" --ddim_eta 1.0 --n_samples 1 --n_iter 1 --H 384 --W 1024 --scale 5.0  
Loading model from models/ldm/text2img-large/model.ckpt
Traceback (most recent call last):
  File "scripts/txt2img.py", line 101, in <module>
    model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt")  # TODO: check path
  File "scripts/txt2img.py", line 18, in load_model_from_config
    model = instantiate_from_config(config.model)
  File "/home/ubuntu/jwk/latent-diffusion/ldm/util.py", line 78, in instantiate_from_config
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
  File "/home/ubuntu/jwk/latent-diffusion/ldm/util.py", line 86, in get_obj_from_str
    return getattr(importlib.import_module(module, package=None), cls)
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 783, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/ubuntu/jwk/latent-diffusion/ldm/models/diffusion/ddpm.py", line 12, in <module>
    import pytorch_lightning as pl
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/__init__.py", line 20, in <module>
    from pytorch_lightning import metrics  # noqa: E402
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/metrics/__init__.py", line 15, in <module>
    from pytorch_lightning.metrics.classification import (  # noqa: F401
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/__init__.py", line 14, in <module>
    from pytorch_lightning.metrics.classification.accuracy import Accuracy  # noqa: F401
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/accuracy.py", line 18, in <module>
    from pytorch_lightning.metrics.utils import deprecated_metrics, void
  File "/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/pytorch_lightning/metrics/utils.py", line 22, in <module>
    from torchmetrics.utilities.data import get_num_classes as _get_num_classes
ImportError: cannot import name 'get_num_classes' from 'torchmetrics.utilities.data' (/home/ubuntu/anaconda3/envs/ldm/lib/python3.8/site-packages/torchmetrics/utilities/data.py)

# environment.yaml
name: ldm
...
dependencies:
  ...
  - pytorch=1.7.0
  - torchvision=0.8.1
  - pip:
    ...
    - pytorch-lightning==1.4.2
    ...

The problem was that pytorch-lightning==1.4.2 automatically runs "from torchmetrics.utilities.data import get_num_classes as _get_num_classes", but that function was dropped by this PR.

So the yaml file should be changed, either by updating pytorch, torchvision, and pytorch-lightning or by pinning an explicit torchmetrics version.

I solved it by updating pytorch-lightning.
