Comments (17)
You are correct: it is done for speed, not correctness. The computation of gradients w.r.t. the weights of netG can be fully avoidedded in the backward pass if the graph is detached where it is.
from examples.
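A minimal sketch of this point, with toy nn.Linear modules and BCEWithLogitsLoss standing in for the example's netG, netD, and criterion (an illustration, not the example's actual code):

import torch
import torch.nn as nn

netG = nn.Linear(10, 10)          # toy stand-in for the generator
netD = nn.Linear(10, 1)           # toy stand-in for the discriminator
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(4, 10)
fake = netG(noise)

# D step: detaching cuts the graph at `fake`, so backward() stops there
# and never computes gradients w.r.t. netG's weights.
out = netD(fake.detach())
errD_fake = criterion(out, torch.zeros(4, 1))
errD_fake.backward()

print(netD.weight.grad is not None)  # True: D received gradients
print(netG.weight.grad is None)      # True: G's gradients were never computed

The whole chunk of gradient work inside netG is simply skipped, which is the speed gain being described.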
@soumith This is not true. Detaching fake from the graph is necessary to avoid forward-passing the noise through G when we actually update the generator. If we do not detach, then, although fake is not needed for the gradient update of D, it will still be added to the computational graph, and as a consequence of the backward pass, which clears all the variables in the graph (retain_graph=False by default), fake won't be available when G is updated.
from examples.
I missed detach when implementing DCGAN in PyTorch, and it gives me this error:
RuntimeError: Trying to backward through the graph second time, but the buffers have already been freed. Please specify retain_variables=True when calling backward for the first time.
from examples.
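A minimal repro of that error, with toy modules standing in for the DCGAN nets (names and sizes are placeholders):

import torch
import torch.nn as nn

netG, netD = nn.Linear(10, 10), nn.Linear(10, 1)
criterion = nn.BCEWithLogitsLoss()

fake = netG(torch.randn(4, 10))   # note: no .detach()

# D step: this backward frees the saved buffers of the whole graph,
# including the part inside netG, because retain_graph=False by default.
errD = criterion(netD(fake), torch.zeros(4, 1))
errD.backward()

# G step: reusing the same `fake` tries to backward through the freed
# graph a second time and raises the RuntimeError quoted above.
errG = criterion(netD(fake), torch.ones(4, 1))
errG.backward()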
@Einstellung
Hi, I have a question about the gradient update of the G model.
First, there are three independent networks: the discriminator network D, the generator network G, and the source network S.
1. The discriminator network takes the merged feature values as input and outputs LogSoftmax(); the loss is a binary classification loss with labels 1 and 0, and the gradients are updated with backward().
2. Then the discriminator takes the output features of generator G as input and outputs LogSoftmax().
What I don't understand is how this connects back to the generator G. D and G are two independent networks, so how does the discriminator's gradient update get linked to G???
The input of the G network is the target data and labels, and its output is fc.
The input of the D network is G's features and the concatenated features of G and S; its output is LogSoftmax().
from examples.
Ah yes, what I said above is only true if we also pass retain_graph=True. My bad, I stand corrected.
from examples.
@plopd what you are saying doesn't make any sense to me
from examples.
Hello,
I am a little confused by this and will greatly appreciate help in understanding.
According to my understanding, detach() prevents further computations from being tracked. (I suppose it also prevents previous computations from being taken into account in the backward pass?)
Either way, wouldn't you want to track the next computation, the operation of D over fake, for the backward pass of D?
If you wanted to prevent tracking of the Generator, wouldn't it make sense to detach before applying G and then restore tracking for D right at the point where detach is now called? (With requires_grad_(True)?)
Thank you
from examples.
Thanks for the quick reply @soumith!
from examples.
What @plopd has said is absolutely right. Detaching fake from the graph is necessary, and not doing so will lead to an error.
from examples.
Let me explain. The role of detach is to freeze the gradient flow. Whether we are updating the discriminator or the generator, everything revolves around log D(G(z)). For the discriminator's update, freezing G does not affect the overall gradient computation: the inner function is treated as a constant, which does not stop us from taking the gradient of the outer function. Conversely, if D were frozen, there would be no way to complete the gradient update at all. That is why we do not freeze D's gradient when training the generator: we do compute gradients through D, but we never update D's weights (only optimizerG.step() is called), so the discriminator is left unchanged while the generator trains. You may ask: in that case, why add detach when training the discriminator? Isn't it an extra step?
Because freezing the gradient speeds up training, we use it wherever we can; it is not extra work. When training the generator, because of log D(G(z)), there is no way to freeze D's gradient, so we do not write detach there.
from examples.
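A sketch of the two phases described above, again with toy nn.Linear modules standing in for the DCGAN nets (not the example's actual code):

import torch
import torch.nn as nn

netG, netD = nn.Linear(10, 10), nn.Linear(10, 1)
optimizerD = torch.optim.SGD(netD.parameters(), lr=0.01)
optimizerG = torch.optim.SGD(netG.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

# D step: G is "frozen" via detach, so only D's gradients are computed.
fake = netG(torch.randn(4, 10))
netD.zero_grad()
errD = criterion(netD(fake.detach()), torch.zeros(4, 1))
errD.backward()
optimizerD.step()

# G step: gradients must flow *through* D to reach G, so no detach here.
netG.zero_grad()
errG = criterion(netD(fake), torch.ones(4, 1))
errG.backward()        # computes gradients for both netD and netG...
optimizerG.step()      # ...but only netG's weights are updated

Note that reusing fake in the G step works here precisely because the D step backpropagated only through the detached branch and left fake's graph intact.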
@shiyuanyin
I don't really understand what you said. You should use G's result to update D.
from examples.
I don't really understand what you said. You should use G's result to update D.
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
label.data.fill_(real_label) # fake labels are real for generator cost
output = netD(fake)
errG = criterion(output, label)
errG.backward()
D_G_z2 = output.data.mean()
optimizerG.step()
########
output = netD(fake): output is related to the D model, not the G model. When backward() runs, D gets the gradient with respect to its input (G's features). How does that gradient get passed on to G, given that they are not connected?
## I think the process is something like netG.backward(D's gradient w.r.t. G's output features).
Is this right???
from examples.
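For what it's worth, autograd handles that connection automatically: as long as fake was produced by netG without detach, there is one graph running from netG through netD, and errG.backward() deposits gradients in netG's parameters with no manual hand-off. A toy check (nn.Linear modules are stand-ins for the example's models):

import torch
import torch.nn as nn

netG, netD = nn.Linear(10, 10), nn.Linear(10, 1)
criterion = nn.BCEWithLogitsLoss()

fake = netG(torch.randn(4, 10))   # no detach: one graph from netG through netD
errG = criterion(netD(fake), torch.ones(4, 1))
errG.backward()

# autograd applied the chain rule through netD down into netG by itself;
# there is no explicit netG.backward(...) call with D's input gradient.
print(netG.weight.grad.abs().sum() > 0)   # tensor(True)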
detach just avoids the work of computing G()'s gradients during the training step of D(), because G() will be trained in the next step anyway
from examples.
@soumith This is not true. Detaching fake from the graph is necessary to avoid forward-passing the noise through G when we actually update the generator. If we do not detach, then, although fake is not needed for the gradient update of D, it will still be added to the computational graph, and as a consequence of the backward pass, which clears all the variables in the graph (retain_graph=False by default), fake won't be available when G is updated.

If I understand correctly, then if I created a new noise input for G, there's no need for the detach() call?
from examples.
If I understand correctly, then if I created a new noise input for G, there's no need for the detach() call?
Do you mean creating fake1 = netG(noise), which is the same as the fake from before the disconnection? I have the same doubt; can someone please clarify this?
from examples.
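One way to see it, as a sketch rather than the example's actual code: if you re-run the forward pass for the G step, you get a fresh graph, so neither detach() nor retain_graph=True is needed (toy modules again stand in for the DCGAN nets):

import torch
import torch.nn as nn

netG, netD = nn.Linear(10, 10), nn.Linear(10, 1)
criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(4, 10)

# D step without detach: fine on its own, but errD.backward() frees
# the buffers of the graph that produced `fake`.
fake = netG(noise)
errD = criterion(netD(fake), torch.zeros(4, 1))
errD.backward()

# G step: a second forward pass builds a brand-new graph.
netG.zero_grad()                  # discard the grads the D step left in netG
fake1 = netG(noise)               # same noise, fresh graph
errG = criterion(netD(fake1), torch.ones(4, 1))
errG.backward()                   # works, no detach() needed

The cost is an extra forward pass through G, which is exactly what the example's fake.detach() avoids.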
I think that's because if you don't use fake.detach() in output = netD(fake.detach()).view(-1), then fake is just some intermediate variable in the whole computational graph; it is tracked from netG() to netD(). When errD.backward() runs (with retain_graph=False by default), all gradient buffers except those of leaf nodes are released, which means no more gradient information about netG() is left in the computational graph. Then when you call errG.backward(), it will cause an error.
from examples.
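As noted earlier in the thread, the alternative to detach() is to keep those buffers alive explicitly with retain_graph=True. A sketch under the same toy-module assumptions as above:

import torch
import torch.nn as nn

netG, netD = nn.Linear(10, 10), nn.Linear(10, 1)
criterion = nn.BCEWithLogitsLoss()

fake = netG(torch.randn(4, 10))   # no detach

# D step: retain_graph=True keeps the intermediate buffers alive...
errD = criterion(netD(fake), torch.zeros(4, 1))
errD.backward(retain_graph=True)

# ...so the G step can reuse `fake` without the RuntimeError above. The
# price is that the D step also computed gradients through netG for nothing.
errG = criterion(netD(fake), torch.ones(4, 1))
errG.backward()                   # no error

This works, but it is strictly more expensive than detaching, which is why the example uses fake.detach() instead.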