Comments (5)

gladisor commented on July 17, 2024

After reviewing this documentation I am still confused about how to update the model parameters. For example:

using Flux
using ChainRulesCore

struct Linear
    W::Matrix
    b::Vector
end

function (l::Linear)(x::Vector)
    return l.W * x .+ l.b
end

function ChainRulesCore.rrule(l::Linear, x::Vector)
    println("calling linear rrule")

    y = l(x)

    function linear_back(Δ)
        println("calling linear back")

        dW = Δ * x'
        db = Δ
        dx = (Δ' * l.W)'

        tangent = Tangent{Linear}(;W = dW, b = db)
        return tangent, dx
    end

    return y, linear_back
end

model = Flux.Chain(
    Linear(randn(2, 2), zeros(2)),
    sum)

x = randn(2)
opt = Descent(0.01)
gs = gradient(m -> m(x), model)

gs looks like this:

((layers = ((W = [-1.3399043659000172 -2.3293859097721454; -1.3399043659000172 -2.3293859097721454], b = Fill(1.0, 2)), nothing),),)

How can I update the model using an optimizer? This doesn't work:

Flux.Optimise.update!(opt, model, gs[1])

As it results in the following error:

ERROR: MethodError: no method matching similar(::Tuple{Linear, typeof(sum)}, ::Type{Tuple{NamedTuple{(:W, :b), Tuple{Matrix{Float64}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}, Nothing}})

from chainrulescore.jl.

CarloLucibello commented on July 17, 2024

Why do you think it isn't working?
I changed the example slightly to show clearly that it is working:

using Flux
using ChainRulesCore

struct Foo
    A::Matrix
    c::Float64
end

Flux.@functor Foo

function foo_mul(foo::Foo, b::AbstractArray)
    return foo.A * b
end

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    println("calling foo rrule")
    y = foo_mul(foo, b)

    function foo_mul_pullback(ȳ)

        f̄ = NoTangent()
        f̄oo = @thunk(Tangent{Foo}(; A=fill!(similar(foo.A), 1), c=ZeroTangent()))
        b̄ = @thunk(foo.A' * ȳ)

        return f̄, f̄oo, b̄
    end

    return y, foo_mul_pullback
end

foo = Foo(randn(2, 2), 1.0)
b = randn(2)

grad = gradient(foo -> sum(foo_mul(foo, b)), foo)[1]
# calling foo rrule
# (A = [1.0 1.0; 1.0 1.0], c = nothing)

gladisor commented on July 17, 2024

Hi Carlo,

Thanks for your response. I see my mistake now. I was trying to extract the parameters of foo but that is unnecessary.

How can I now use this to update foo with an optimizer?

This code is not working properly:

gs = gradient(foo -> sum(foo_mul(foo, b)), foo)[1]

opt = Descent(0.01)
Flux.Optimise.update!(opt, foo, gs)

CarloLucibello commented on July 17, 2024

http://fluxml.ai/Flux.jl/stable/training/training/
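
For readers who land here, a minimal sketch of an explicit-style update, assuming a recent Flux version that provides `Flux.setup` and the three-argument `Flux.update!`. It reuses the `Foo` and `foo_mul` definitions from the comment above but leaves out the custom `rrule` so that the gradient is the ordinary one; `A_before` is a name introduced here only to make the effect of the step visible.

```julia
using Flux

struct Foo
    A::Matrix{Float64}
    c::Float64
end

# Let Flux walk the fields of Foo; only the array field A is treated as trainable.
Flux.@functor Foo

foo_mul(foo::Foo, b::AbstractArray) = foo.A * b

foo = Foo(randn(2, 2), 1.0)
b = randn(2)

# Build optimizer state that mirrors the structure of foo.
opt_state = Flux.setup(Descent(0.01), foo)

# Differentiate with respect to the model itself (explicit style);
# the result is a 1-tuple holding a NamedTuple with fields A and c.
grads = Flux.gradient(m -> sum(foo_mul(m, b)), foo)

A_before = copy(foo.A)  # kept only for comparison (assumption: illustration only)

# Update foo.A in place using the matching gradient structure.
Flux.update!(opt_state, foo, grads[1])
```

The key difference from the failing calls earlier in the thread is that the optimizer state is created from the model with `Flux.setup`, and `update!` receives that state rather than the bare `Descent` rule. The scalar field `c` is left untouched because it is not an array.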

gladisor commented on July 17, 2024

Never mind, I figured it out. Will post a full solution later.
