Comments (7)
Yes, this works for my problem, and the accum patch is no longer needed. Thanks!
from yaoblocks.jl.
I haven't had time to look into the Zygote issue yet (I'm currently traveling in Boston), e.g. the one @GiggleLiu posted:
julia> Zygote.gradient(x->x.content.theta, Daggered(Rx(0.5)))
(nothing,)
Let me take a second look at this a bit later to fully resolve this issue. But I'm glad to see the main case is fixed.
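For comparison, here is a minimal sketch of what Zygote returns for the same kind of nested field access on plain structs. `ToyRot` and `ToyDaggered` are hypothetical stand-ins invented for this illustration; they are not Yao types and say nothing about why the Yao call above returns `nothing`. The only point is that Zygote's default gradient for a struct is a NamedTuple that mirrors its fields.

```julia
using Zygote

# Hypothetical stand-ins for a parameterized gate and a wrapper block.
# These are NOT Yao types; they only mimic the nesting `x.content.theta`.
struct ToyRot
    theta::Float64
end

struct ToyDaggered
    content::ToyRot
end

# Differentiate a nested field access with respect to the outer struct.
g = Zygote.gradient(x -> x.content.theta, ToyDaggered(ToyRot(0.5)))[1]

# The gradient is a nested NamedTuple with the same field names as the structs.
println(g)
```

With plain structs the derivative with respect to `theta` shows up as `g.content.theta`, which is the structure a custom block type would presumably need to reproduce through its own rules.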
from yaoblocks.jl.
@GiggleLiu perfect, I have tested a few cases and indeed this patch works for me! #171
Do you recommend I put the Zygote.accum method for AbstractBlock in my own modules or is it generally applicable and will also be part of chainrules_patch.jl? (now I see it is in the test file)
We will not add this patch to YaoBlocks, because Zygote is very slow to load and sometimes has version issues. E.g., tests currently break on nightly because Zygote is used in the tests. In the future, we might switch to a more correct implementation that constructs a Tangent type for the circuit.
from yaoblocks.jl.
Thanks for the issue. This problem is caused by returning circuit gradients as a vector. I made a patch for it; can you check whether it solves your issue? #171
from yaoblocks.jl.
@GiggleLiu perfect, I have tested a few cases and indeed this patch works for me! #171
Do you recommend I put the Zygote.accum method for AbstractBlock in my own modules or is it generally applicable and will also be part of chainrules_patch.jl? (now I see it is in the test file)
from yaoblocks.jl.
@GiggleLiu ok understood! One remaining issue with the current rrules shows up after the following tiny modification I made to the code:
using Zygote
using Yao
using YaoBlocks
function Zygote.accum(a::AbstractBlock, b::AbstractBlock)
    dispatch(a, parameters(a) + parameters(b))
end
N=2
psi_0 = zero_state(N)
U0 = chain(N, put(1=>Rx(0.0)), put(2=>Ry(0.0)))
function loss(theta)
    C = sum([chain(N, put(k=>Z)) for k=1:N])
    U = dispatch(U0, theta)
    psi0 = copy(psi_0)
    psi1 = apply(psi0, U)
    psi2 = apply(psi1, C)
    result = real(sum(conj(state(psi1)) .* state(psi2)))
    return result
end
theta = [1.1,2.2]
println(expect'(C, copy(psi_0) => dispatch(U0, theta))[2])
grad = Zygote.gradient(theta->loss(theta), theta)[1]
println(grad)
Compared to the code at the opening of this issue, I only added Zygote.accum and moved C, which is an Add block of chains, into the loss function. C does not even depend on the parameters, but I get the following error:
ERROR: LoadError: Need an adjoint for constructor ChainBlock{2}. Gradient is of type Add{2}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{ChainBlock{2}, Nothing, false})(Δ::Add{2})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/lib.jl:324
[3] (::Zygote.var"#1768#back#224"{Zygote.Jnew{ChainBlock{2}, Nothing, false}})(Δ::Add{2})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:13 [inlined]
[5] (::typeof(∂(ChainBlock{2})))(Δ::Add{2})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:13 [inlined]
[7] (::typeof(∂(ChainBlock)))(Δ::Add{2})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:17 [inlined]
[9] (::typeof(∂(ChainBlock)))(Δ::Add{2})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[10] Pullback
@ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:48 [inlined]
[11] (::typeof(∂(chain)))(Δ::Add{2})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[12] Pullback
@ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:45 [inlined]
[13] (::typeof(∂(chain)))(Δ::Add{2})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[14] Pullback
@ ./none:0 [inlined]
[15] (::typeof(∂(#201)))(Δ::Add{2})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[16] #557
@ ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:202 [inlined]
[17] #4
@ ./generator.jl:36 [inlined]
[18] iterate
@ ./generator.jl:47 [inlined]
[19] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{ChainBlock{2}, typeof(∂(#201))}}, FillArrays.Fill{Add{2}, 1, Tuple{Base.OneTo{Int64}}}}}, Base.var"#4#5"{Zygote.var"#557#562"}})
@ Base ./array.jl:678
[20] map
@ ./abstractarray.jl:2383 [inlined]
[21] map_back
@ ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:202 [inlined]
[22] (::Zygote.var"#back#591"{Zygote.var"#map_back#561"{var"#201#202", 1, Tuple{UnitRange{Int64}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{ChainBlock{2}, typeof(∂(#201))}}}})(ȳ::FillArrays.Fill{Add{2}, 1, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:247
[23] Pullback
@ ~/reproducing_chainblock_bug.jl:14 [inlined]
[24] (::typeof(∂(loss)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
[25] Pullback
@ ~/reproducing_chainblock_bug.jl:25 [inlined]
[26] (::Zygote.var"#55#56"{typeof(∂(#203))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:41
[27] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:76
[28] top-level scope
@ ~/reproducing_chainblock_bug.jl:25
[29] include(fname::String)
@ Base.MainInclude ./client.jl:444
[30] top-level scope
@ REPL[7]:1
in expression starting at /reproducing_chainblock_bug.jl:25
This is not only a problem in the current case: in general I would like to differentiate through blocks of this type, for example when a parameter (from the perspective of Zygote, not necessarily a Yao dispatched parameter) appears in them. Please let me know if you want me to open a new issue with a copy of this comment.
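For the narrower case where the object built inside the loss does not depend on the parameters at all, one possible workaround is to tell Zygote not to differentiate through its construction. The sketch below uses `Zygote.ignore` on a hypothetical `ToyObservable` struct and a hypothetical `toy_loss`; neither is a Yao type or function. Whether the same trick removes the `ChainBlock` constructor error above is an assumption that has not been tested against Yao, and it does not help when a Zygote-level parameter really does appear in the block.

```julia
using Zygote

# Hypothetical stand-in for a constant observable; this is NOT a Yao type.
struct ToyObservable
    weights::Vector{Float64}
end

function toy_loss(theta)
    # Construct the constant inside `Zygote.ignore` so that Zygote treats it
    # as non-differentiable and never asks for an adjoint of the constructor.
    obs = Zygote.ignore() do
        ToyObservable([1.0, -1.0])
    end
    return sum(obs.weights .* cos.(theta))
end

theta = [1.1, 2.2]
grad = Zygote.gradient(toy_loss, theta)[1]

# Analytically, d/dtheta_i of w_i * cos(theta_i) is -w_i * sin(theta_i).
println(grad)
```

Any gradient that flows back toward `obs` is simply dropped at the `ignore` boundary, which is why this only suits objects that are truly constant with respect to `theta`.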
from yaoblocks.jl.
Hi, I have now used the correct tangent type for gradients. It seems to solve your problem, and you do not need the patch anymore. However, some issues remain unsettled. I am not good at debugging Zygote; if someone can help, that would be great. I posted a WIP PR in #173.
from yaoblocks.jl.