Comments (7)

vincentelfving commented on May 28, 2024

Yes, this works for my problem, and the accum patch is no longer needed. Thanks!

from yaoblocks.jl.

Roger-luo commented on May 28, 2024

I haven't had time to look into the Zygote issue (I'm currently traveling in Boston), e.g. the one @GiggleLiu posted:

julia> Zygote.gradient(x->x.content.theta, Daggered(Rx(0.5)))
(nothing,)

Let me take a second look a bit later to fully resolve this issue. I'm glad to see the main case is fixed, though.

GiggleLiu commented on May 28, 2024

> @GiggleLiu perfect, I have tested a few cases and indeed this patch works for me! #171
>
> Do you recommend I put the Zygote.accum method for AbstractBlock in my own modules or is it generally applicable and will also be part of chainrules_patch.jl? (now I see it is in the test file)

We will not add this patch to YaoBlocks, because Zygote is very slow to load and sometimes has version issues. For example, the tests currently break on nightly because they use Zygote. In the future, we might switch to a more correct implementation that constructs a Tangent type for the circuit.

GiggleLiu commented on May 28, 2024

Thanks for the issue. This problem is caused by returning circuit gradients as a vector. I made a patch for it; can you check whether it solves your issue? #171

vincentelfving commented on May 28, 2024

@GiggleLiu perfect, I have tested a few cases and indeed this patch works for me! #171

Do you recommend I put the Zygote.accum method for AbstractBlock in my own modules or is it generally applicable and will also be part of chainrules_patch.jl? (now I see it is in the test file)

vincentelfving commented on May 28, 2024

@GiggleLiu OK, understood! One remaining issue with the current rrules shows up after the following tiny modification I made to the code:

using Zygote
using Yao
using YaoBlocks

function Zygote.accum(a::AbstractBlock, b::AbstractBlock)
    dispatch(a, parameters(a) + parameters(b))
end

N=2
psi_0 = zero_state(N)
U0 = chain(N, put(1=>Rx(0.0)), put(2=>Ry(0.0)))

function loss(theta)
    C = sum([chain(N, put(k=>Z)) for k=1:N])
    U = dispatch(U0, theta)
    psi0 = copy(psi_0)
    psi1 = apply(psi0, U)
    psi2 = apply(psi1, C)
    result = real(sum(conj(state(psi1)) .* state(psi2)))
    return result
end

theta = [1.1,2.2]
println(expect'(C, copy(psi_0) => dispatch(U0, theta))[2])
grad = Zygote.gradient(theta->loss(theta), theta)[1]
println(grad)

Compared to the code at the opening of this issue, I only added the Zygote.accum method and moved C, which is an Add block of chain blocks, into the loss function. C does not even depend on the parameters, but I get the following error:

ERROR: LoadError: Need an adjoint for constructor ChainBlock{2}. Gradient is of type Add{2}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{ChainBlock{2}, Nothing, false})(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/lib.jl:324
  [3] (::Zygote.var"#1768#back#224"{Zygote.Jnew{ChainBlock{2}, Nothing, false}})(Δ::Add{2})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:13 [inlined]
  [5] (::typeof(∂(ChainBlock{2})))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:13 [inlined]
  [7] (::typeof(∂(ChainBlock)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:17 [inlined]
  [9] (::typeof(∂(ChainBlock)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:48 [inlined]
 [11] (::typeof(∂(chain)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:45 [inlined]
 [13] (::typeof(∂(chain)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./none:0 [inlined]
 [15] (::typeof(∂(#201)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [16] #557
    @ ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:202 [inlined]
 [17] #4
    @ ./generator.jl:36 [inlined]
 [18] iterate
    @ ./generator.jl:47 [inlined]
 [19] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{ChainBlock{2}, typeof(∂(#201))}}, FillArrays.Fill{Add{2}, 1, Tuple{Base.OneTo{Int64}}}}}, Base.var"#4#5"{Zygote.var"#557#562"}})
    @ Base ./array.jl:678
 [20] map
    @ ./abstractarray.jl:2383 [inlined]
 [21] map_back
    @ ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:202 [inlined]
 [22] (::Zygote.var"#back#591"{Zygote.var"#map_back#561"{var"#201#202", 1, Tuple{UnitRange{Int64}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{ChainBlock{2}, typeof(∂(#201))}}}})(ȳ::FillArrays.Fill{Add{2}, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:247
 [23] Pullback
    @ ~/reproducing_chainblock_bug.jl:14 [inlined]
 [24] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [25] Pullback
    @ ~/reproducing_chainblock_bug.jl:25 [inlined]
 [26] (::Zygote.var"#55#56"{typeof(∂(#203))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:41
 [27] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:76
 [28] top-level scope
    @ ~/reproducing_chainblock_bug.jl:25
 [29] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [30] top-level scope
    @ REPL[7]:1
in expression starting at /reproducing_chainblock_bug.jl:25

This is not only a problem right now: in general I would like to differentiate through blocks of this type, for example when a parameter (from the perspective of Zygote, not necessarily a Yao dispatched parameter) appears in them. Please let me know if you want me to open a new issue as a copy of this comment.

GiggleLiu commented on May 28, 2024

Hi, I have now used the correct tangent type for gradients. It seems to solve your problem, and you do not need the patch anymore. However, some issues remain unsettled. I am not good at debugging Zygote, so if someone can help, that would be great. I posted a WIP PR in #173.
