DeepCorrections

This package implements the deep correction method [1] for solving reinforcement learning problems. The user should define the problem according to the POMDPs.jl interface.

[1] M. Bouton, K. Julian, A. Nakhaei, K. Fujimura, and M. J. Kochenderfer, “Utility decomposition with deep corrections for scalable planning under uncertainty,” in International Conference on Autonomous Agents and Multiagent Systems (AAMAS), 2018.

Installation

using Pkg
Pkg.add("POMDPs")
using POMDPs # must be loaded before calling POMDPs.add_registry()
POMDPs.add_registry() # to get DeepQLearning and RLInterface
Pkg.add(PackageSpec(url="https://github.com/sisl/DeepCorrections.jl"))

Usage

using POMDPs
using DeepCorrections
using Flux # for model definition
using DeepQLearning # for underlying DQN solver
using POMDPModels # for gridworld

mdp = SimpleGridWorld()

function my_low_fidelity_values(problem::SimpleGridWorld, s)
    return ones(n_actions(problem)) # dummy example, should return an action value vector 
end

model = Chain(Dense(2, 32, relu), Dense(32, n_actions(mdp))) # input is 2 dimensional, x,y positions in grid world
dqn_solver = DeepQLearningSolver(qnetwork = model, verbose=true) # see DQN docs for all the parameters
solver = DeepCorrectionSolver(dqn = dqn_solver,
                              lowfi_values = my_low_fidelity_values)

policy = solve(solver, mdp)

Documentation

The type DeepCorrectionSolver relies on the DeepQLearningSolver type defined in DeepQLearning.jl. The deep correction solver supports all the options available for the DeepQLearningSolver.

solve returns a DeepCorrectionPolicy object. It can be used like any policy in the POMDPs.jl interface.

Low fidelity value estimation:

To provide the low fidelity value function to the solver, set the lowfi_values option when initializing the solver. It can be a function or a policy. If it is a function f, f(mdp, s) is called to estimate the action values. If it is a policy, actionvalues(policy, s) is called; see the documentation in POMDPPolicies for more details on actionvalues. The output should be a vector of length n_actions(mdp). The actions are assumed to be ordered according to the function action_index implemented by the problem writer.
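As an illustration of the function form f(mdp, s), here is a minimal sketch in plain Julia. The ToyGrid type, the TOY_ACTIONS ordering, and the manhattan_lowfi_values name are hypothetical and not part of DeepCorrections.jl; a real problem would implement the POMDPs.jl interface, and the order of the returned vector would have to match its action_index.

```julia
# Hypothetical stand-in for a grid-world problem (not a POMDPs.jl type).
struct ToyGrid
    goal::Tuple{Int,Int}
end

# Assumed action ordering: up, down, left, right.
# In a real problem this must agree with action_index.
const TOY_ACTIONS = [(0, 1), (0, -1), (-1, 0), (1, 0)]

# Low fidelity estimate: negative Manhattan distance to the goal
# after applying each action from state s = (x, y).
function manhattan_lowfi_values(problem::ToyGrid, s)
    gx, gy = problem.goal
    return [-Float64(abs(s[1] + dx - gx) + abs(s[2] + dy - gy))
            for (dx, dy) in TOY_ACTIONS]
end

values = manhattan_lowfi_values(ToyGrid((3, 3)), (1, 1))
```

A function of this shape, written for the actual problem type, would be passed to the solver through the lowfi_values option.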

Correction method:

Two default correction methods are available:

  • additive correction: Q_lo(s, a) + delta(s, a), where Q_lo is the result of lowfi_values and delta is the correction network.
  • multiplicative correction: Q_lo(s, a) * delta(s, a)

An additional constant weight can be used in the correction method through the correction_weight option of the solver. Users can write their own correction method via the correction_method option. It can be a function or an object. If it is a function f, f(mdp, q_lo, q_corr, correction_weight) is called to compute the corrected value. If it is an object o, correction(o, mdp, q_lo, q_corr, correction_weight) is called.

The underlying implementation relies on TensorFlow and static graphs, so any correction method you implement must support tensor inputs and be TensorFlow friendly. The signature should look like this:

    multiplicative_correction(problem::Union{POMDP, MDP}, q_lo::Q, q_corr::Q, weight::Float64) where Q <: Union{Array{Float64}, Tensor}
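A custom correction in the function form f(mdp, q_lo, q_corr, correction_weight) could be sketched as follows. The name weighted_additive_correction is hypothetical, and scaling the correction term by the weight is an assumption made for illustration; the text above does not state how the built-in methods apply correction_weight. The sketch is exercised here on plain arrays only, so its behavior on tensor inputs is not checked.

```julia
# Hypothetical custom correction: low fidelity values plus a weighted correction.
# Argument order follows f(mdp, q_lo, q_corr, correction_weight).
function weighted_additive_correction(mdp, q_lo, q_corr, correction_weight)
    # Element-wise over the action dimension; mdp is unused in this sketch.
    return q_lo .+ correction_weight .* q_corr
end

q_lo = [1.0, 2.0, 3.0]      # example low fidelity action values
q_corr = [0.5, -0.5, 0.0]   # example correction network outputs
q = weighted_additive_correction(nothing, q_lo, q_corr, 0.5)
```

Using only broadcast arithmetic keeps the function free of control flow that depends on the values, which is in the spirit of the static-graph constraint mentioned above.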
