
Comments (9)

GilesBathgate avatar GilesBathgate commented on July 20, 2024 1

I only have 1 GPU, so I was not running in distributed mode. Perhaps that's why. I had to make another patch to support only 1 GPU.

from diffusion-gan.

GilesBathgate avatar GilesBathgate commented on July 20, 2024 1
--- a/diffusion-insgen/torch_utils/misc.py
+++ b/diffusion-insgen/torch_utils/misc.py
@@ -150,7 +150,9 @@ def copy_params_and_buffers(src_module, dst_module, require_all=False):
     for name, tensor in named_params_and_buffers(dst_module):
         assert (name in src_tensors) or (not require_all)
         if name in src_tensors:
-            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+            requires_grad = tensor.requires_grad
+            with torch.no_grad():
+                tensor.copy_(src_tensors[name].detach())
+            tensor.requires_grad_(requires_grad)
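The failure mode this patch addresses can be reproduced in isolation: `copy_()` is an in-place operation, and PyTorch refuses to apply it to a leaf tensor that requires grad. A minimal sketch (my own repro, not code from the repository), showing the error and the patched pattern of copying under `torch.no_grad()` and restoring the flag afterwards:

```python
import torch

src = torch.randn(3)
dst = torch.randn(3, requires_grad=True)

# In-place copy_ on a leaf tensor that requires grad raises a RuntimeError.
try:
    dst.copy_(src.detach())
except RuntimeError as e:
    print("copy_ failed:", e)

# Patched pattern: remember the flag, copy with grad tracking suspended,
# then restore the flag on the destination tensor.
requires_grad = dst.requires_grad
with torch.no_grad():
    dst.copy_(src.detach())
dst.requires_grad_(requires_grad)
```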


Zhendong-Wang avatar Zhendong-Wang commented on July 20, 2024

Hi @GilesBathgate,

Your fix is correct. The other error on resuming originates from InsGen: genforce/insgen#6. I tried to fix it but didn't work it out. I guess the error comes from DHead and GHead; I don't know where they perform an in-place operation. We'll need to wait for the InsGen authors to solve this, lol...


Zhendong-Wang avatar Zhendong-Wang commented on July 20, 2024

Thanks a lot! Will make the change.


GilesBathgate avatar GilesBathgate commented on July 20, 2024

The fix should probably be this:

--- a/diffusion-insgen/training/training_loop.py
+++ b/diffusion-insgen/training/training_loop.py
@@ -154,22 +154,22 @@ def training_loop(
 
     # Construct networks.
     if rank == 0:
         print('Constructing networks...')
     common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
     G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
     D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
     G_ema = copy.deepcopy(G).eval()
 
     # Construct contrastive heads.
-    DHead = dnnlib.util.construct_class_by_name(**DHead_kwargs).train().to(device) if DHead_kwargs is not None else None
-    GHead = dnnlib.util.construct_class_by_name(**GHead_kwargs).train().to(device) if GHead_kwargs is not None else None
+    DHead = dnnlib.util.construct_class_by_name(**DHead_kwargs).train().requires_grad_(False).to(device) if DHead_kwargs is not None else None
+    GHead = dnnlib.util.construct_class_by_name(**GHead_kwargs).train().requires_grad_(False).to(device) if GHead_kwargs is not None else None
     D_ema = copy.deepcopy(D).eval()
 
     # Setup augmentation.


@@ -221,6 +224,8 @@ def training_loop(
             ddp_modules[name] = module
 
     # Distribute Heads across GPUs.
+    DHead.requires_grad_(True)
+    GHead.requires_grad_(True)
     if rank == 0:
         print(f'Distributing Contrastive Heads across {num_gpus} GPUS...')
     if num_gpus > 1:

This seems to fit the intent of the original StyleGAN code better.
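The convention the patch restores can be illustrated with a toy sketch (not the actual StyleGAN2-ADA code; `run_phase` is a hypothetical helper): every module is constructed with grads disabled, and the training loop enables `requires_grad` only for the module being optimized in the current phase, then freezes it again afterwards.

```python
import torch
import torch.nn as nn

# All modules start frozen, matching the patched construction of DHead/GHead.
G = nn.Linear(4, 4).train().requires_grad_(False)
D = nn.Linear(4, 4).train().requires_grad_(False)
DHead = nn.Linear(4, 4).train().requires_grad_(False)  # contrastive head, same treatment

def run_phase(module, loss_fn):
    module.requires_grad_(True)    # enable grads only for this phase
    loss_fn(module).backward()
    module.requires_grad_(False)   # return to the default frozen state

run_phase(D, lambda m: m(torch.randn(2, 4)).sum())
# Gradients were accumulated during the phase, but the module ends up frozen.
assert all(p.grad is not None for p in D.parameters())
assert not any(p.requires_grad for p in D.parameters())
```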


Zhendong-Wang avatar Zhendong-Wang commented on July 20, 2024

@GilesBathgate Really appreciate your investigation here 💯. I will test the code and update accordingly.


Zhendong-Wang avatar Zhendong-Wang commented on July 20, 2024

This fix seems not to work when saving ckpts. Do you know what the possible reason could be, @GilesBathgate?

Distributing across 2 GPUs...
Distributing Contrastive Heads across 2 GPUS...
Setting up training phases...
Setting up contrastive training phases...
Exporting sample images...
Initializing logs...
Skipping tfevents export: No module named 'tensorboard'
Training for 25000 kimg...

tick 0     kimg 0.1      time 18s          sec/tick 5.9     sec/kimg 92.60   maintenance 11.7   cpumem 4.10   gpumem 11.92  augment 0.000 T 10.0
Traceback (most recent call last):
  File "train.py", line 603, in <module>
    main() # pylint: disable=no-value-for-parameter
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/click/decorators.py", line 26, in new_func
    return f(get_current_context(), *args, **kwargs)
  File "train.py", line 598, in main
    torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/zdwang/.conda/envs/difgan/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/zdwang/Research/Diffusion-GAN/diffusion-insgen/train.py", line 422, in subprocess_fn
    training_loop.training_loop(rank=rank, **args)
  File "/home/zdwang/Research/Diffusion-GAN/diffusion-insgen/training/training_loop.py", line 432, in training_loop
    misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
  File "/home/zdwang/Research/Diffusion-GAN/diffusion-insgen/torch_utils/misc.py", line 180, in check_ddp_consistency
    assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
AssertionError: DistributedDataParallel.module.mlp.0.weight


Zhendong-Wang avatar Zhendong-Wang commented on July 20, 2024

Thanks @GilesBathgate! I remember you had another fix, which found some in-place operation in InsGen's Contrastive_Head. That fix was deleted (I don't know why...). Do you mind sharing it again so I can try that one? I couldn't find where it is, lol. Thanks again!


GilesBathgate avatar GilesBathgate commented on July 20, 2024

I proposed a change here:

tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)

Essentially, disable grad before copying, then re-enable it. However, I prefer the fix above, which should have the same effect, since grad should not be enabled before misc.copy_params_and_buffers is called, as is the case for the other modules.

I don't think either of these fixes will solve your error in misc.check_ddp_consistency.
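For context, the assertion in the traceback fires when a parameter on a non-zero rank is no longer bit-identical to rank 0's copy; in the real check the reference tensor arrives via torch.distributed broadcast. A single-process sketch of just the comparison logic (my own simulation of the "other rank" tensor, not the repository's code):

```python
import torch

def consistent(tensor, other):
    """Mirror of the check's core assertion: equal after normalizing NaNs."""
    nan_to_num = lambda t: torch.nan_to_num(t, nan=0.0)
    return bool((nan_to_num(tensor) == nan_to_num(other)).all())

w = torch.randn(3, 3)
assert consistent(w, w.clone())      # synced replicas pass
assert not consistent(w, w + 1e-3)   # drifted replicas fail, e.g. a head
                                     # updated without DDP gradient sync
```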

