term_a = torch.log(self.prior_d(prior)).mean()
term_b = torch.log(1.0 - self.prior_d(y)).mean()
PRIOR = - (term_a + term_b) * self.gamma
"-(term_a + term_b)" is the loss of Discriminator, and “term_b” is the loss of encoder( similar as generator of gan )
In the code you only backward Discriminator's loss(part of prior distribution), and there is no backward of the loss that belongs to the encoder in the prior distribution.
loss.backward()    # loss = global + local + prior, where prior = -(term_a + term_b)
optim.step()
loss_optim.step()
term_a = torch.log(self.prior_d(prior)).mean()
term_b = torch.log(1.0 - self.prior_d(y.detach())).mean()  # y must be detached for the Discriminator's loss
PRIOR = - (term_a + term_b) * self.gamma
# recomputed WITHOUT detach, otherwise no gradient reaches the encoder
encoder_loss_for_p = torch.log(1.0 - self.prior_d(y)).mean()
.............
loss.backward()                # loss = global + local + PRIOR, PRIOR = -(term_a + term_b)
optim.step()                   # updates the encoder from global + local only (y was detached, so PRIOR adds no encoder gradient)
loss_optim.step()              # updates the Discriminator from the PRIOR term
encoder_loss_for_p.backward()  # adversarial update: train the encoder to fool the Discriminator
optim.step()
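To make the two-step update concrete, here is a minimal self-contained sketch of the scheme proposed above. The module architectures, optimizer choices, and data are placeholders of my own (the repo's real `encoder` and `prior_d` differ); only the detach/no-detach pattern is the point:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(8, 4)                               # stand-in for the real encoder
prior_d = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())  # stand-in prior discriminator

d_optim = torch.optim.SGD(prior_d.parameters(), lr=0.1)
e_optim = torch.optim.SGD(encoder.parameters(), lr=0.1)

x = torch.randn(16, 8)
y = encoder(x)              # encoded samples
prior = torch.rand(16, 4)   # samples from the prior (uniform here, as an assumption)

# --- Discriminator step: y is detached, so no gradient reaches the encoder ---
d_optim.zero_grad()
term_a = torch.log(prior_d(prior)).mean()
term_b = torch.log(1.0 - prior_d(y.detach())).mean()
d_loss = -(term_a + term_b)
d_loss.backward()
d_optim.step()

# --- Encoder step: recompute term_b WITHOUT detaching, so the gradient flows
# back through the encoder; only the encoder optimizer is stepped, and the
# stray gradients this leaves on prior_d are cleared by its next zero_grad() ---
e_optim.zero_grad()
encoder_loss_for_p = torch.log(1.0 - prior_d(encoder(x))).mean()
encoder_loss_for_p.backward()
e_optim.step()
```

With `encoder_loss_for_p = term_b` from the detached graph, `backward()` would leave `encoder.weight.grad` empty; recomputing from the live `y` is what makes the adversarial encoder update actually happen.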