sisl / ceem

Certainty-Equivalent Expectation Maximization: a scalable algorithm for system identification of partially observed systems

Home Page: https://sites.google.com/stanford.edu/ceem

License: MIT License


CE-EM

Official implementation of the CE-EM algorithm and the Particle EM baseline from "Scalable Identification of Partially Observed Systems with Certainty-Equivalent EM".

Website: https://sites.google.com/stanford.edu/ceem

Usage

Ensure you are using Python 3.6 or later, then install the package:

pip install CEEM

Run python -m pytest to ensure everything works.

A Jupyter notebook demonstrating usage can be found in the examples subfolder.

Code overview

  • ceem/dynamics.py defines the system API used by the CE-EM algorithm.
  • ceem/systems/*.py define the systems used in the experiments.
  • ceem/ceem.py contains the CE-EM algorithm.
  • ceem/smoother.py defines the smoothing routines used by the CE-EM algorithm in the smoothing step.
  • ceem/learner.py defines the learning routines used by the CE-EM algorithm in the learning step.
  • ceem/opt_criteria.py defines the optimization criteria used by the CE-EM algorithm.
  • ceem/particleem.py implements Particle EM.
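
The alternation between a smoothing step and a learning step can be illustrated with a toy sketch. It does not use the ceem package or its API. It assumes a hypothetical scalar system x_{t+1} = a x_t with noisy observations y_t = x_t, equal weights on both residuals, and plain least squares in each step; the function names (smooth, learn, ce_em_toy, make_data) are invented for illustration, and any regularization used by the actual algorithm is omitted.

```python
import numpy as np


def smooth(y, a, w_obs=1.0, w_dyn=1.0):
    """Smoothing step: point estimate of the states for a fixed parameter a.

    Minimizes w_obs * sum_t (x_t - y_t)^2 + w_dyn * sum_t (x_{t+1} - a x_t)^2,
    which is a linear least-squares problem in x.
    """
    T = len(y)
    idx = np.arange(T - 1)
    dyn = np.zeros((T - 1, T))
    dyn[idx, idx] = -a        # coefficient of x_t in the dynamics residual
    dyn[idx, idx + 1] = 1.0   # coefficient of x_{t+1} in the dynamics residual
    A = np.vstack([np.sqrt(w_obs) * np.eye(T), np.sqrt(w_dyn) * dyn])
    b = np.concatenate([np.sqrt(w_obs) * y, np.zeros(T - 1)])
    x, *_ = np.linalg.lstsq(A, b, rcond=None)
    return x


def learn(x):
    """Learning step: closed-form least-squares fit of a to the smoothed states."""
    return float(x[1:] @ x[:-1] / (x[:-1] @ x[:-1]))


def objective(x, y, a, w_obs=1.0, w_dyn=1.0):
    """Joint objective that both steps decrease."""
    obs = np.sum((x - y) ** 2)
    dyn = np.sum((x[1:] - a * x[:-1]) ** 2)
    return float(w_obs * obs + w_dyn * dyn)


def ce_em_toy(y, a0, iters=30):
    """Alternate smoothing and learning, recording the objective after each pass."""
    a, history = a0, []
    for _ in range(iters):
        x = smooth(y, a)   # states given the current parameter
        a = learn(x)       # parameter given the current states
        history.append(objective(x, y, a))
    return a, x, history


def make_data(a_true=0.9, x0=5.0, T=50, obs_std=0.01, seed=0):
    """Noisy observations of a deterministic decaying trajectory (assumed setup)."""
    rng = np.random.default_rng(seed)
    x = x0 * a_true ** np.arange(T)
    return x + obs_std * rng.standard_normal(T)
```

Calling ce_em_toy(make_data(), a0=0.5) returns the parameter estimate, the smoothed trajectory, and the objective history. Each step minimizes the joint objective over one block of variables, so the recorded objective should not increase between passes.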

Experiments

Lorenz

Unbiased Estimation in Deterministic Settings

To regenerate the data in data/lorenz/bias_experiment run:

python experiments/lorenz/bias_experiment.py

To generate Table 1 run:

python experiments/lorenz/plotting/process_bias.py

Comparison to Particle Based Methods

To regenerate the data in data/lorenz/comp run:

python experiments/lorenz/comp_pem.py
python experiments/lorenz/comp_ceem.py

To generate Figure 2 run:

python experiments/lorenz/plotting/process_comp.py

Convergence of CE-EM on High Dimensional Problems

To regenerate the data in data/lorenz/convergence_experiment run:

python experiments/lorenz/convergence_experiment_pem.py
python experiments/lorenz/convergence_experiment_ceem.py

To generate Figure 3 run:

python experiments/lorenz/plotting/process_convergence.py
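
For readers unfamiliar with the Lorenz system, the following is a minimal standalone sketch of simulating it and observing only part of the state. It is not taken from the experiment scripts: the classic parameter values (sigma = 10, rho = 28, beta = 8/3), the RK4 integrator, the step size, the noise level, and the choice to observe only the first coordinate are all assumptions made for illustration.

```python
import numpy as np


def lorenz_rhs(x, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Time derivative of the three-dimensional Lorenz system."""
    return np.array([
        sigma * (x[1] - x[0]),
        x[0] * (rho - x[2]) - x[1],
        x[0] * x[1] - beta * x[2],
    ])


def rk4_step(x, dt, **params):
    """One classical fourth-order Runge-Kutta step of size dt."""
    k1 = lorenz_rhs(x, **params)
    k2 = lorenz_rhs(x + 0.5 * dt * k1, **params)
    k3 = lorenz_rhs(x + 0.5 * dt * k2, **params)
    k4 = lorenz_rhs(x + dt * k3, **params)
    return x + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def simulate(x0, steps, dt=0.01, obs_std=0.1, seed=0):
    """Return the full state trajectory and noisy partial observations.

    Only the first state coordinate is observed (an assumption of this sketch).
    """
    rng = np.random.default_rng(seed)
    xs = np.empty((steps, 3))
    xs[0] = x0
    for t in range(1, steps):
        xs[t] = rk4_step(xs[t - 1], dt)
    ys = xs[:, :1] + obs_std * rng.standard_normal((steps, 1))
    return xs, ys
```

A call such as simulate(np.array([1.0, 1.0, 1.0]), steps=200) yields a state array of shape (200, 3) and an observation array of shape (200, 1). An identification method sees only the observations and must recover the parameters.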

Helicopter

The following scripts train the models from Section 4.2 of the paper. Pretrained models are provided in the pretrained_models folder.

Data download

The dataset used in our experiments can be downloaded by running:

wget 'https://zenodo.org/record/3662987/files/datasets.zip?download=1' -O datasets.zip
unzip datasets.zip
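
Where wget and unzip are unavailable, the same two steps can be sketched with the Python standard library. The helper names download and extract are hypothetical and not part of the repository; the URL is the one given above.

```python
import urllib.request
import zipfile
from pathlib import Path

DATASET_URL = "https://zenodo.org/record/3662987/files/datasets.zip?download=1"


def download(url=DATASET_URL, target="datasets.zip"):
    """Fetch the archive unless a file with that name already exists."""
    target = Path(target)
    if not target.exists():
        urllib.request.urlretrieve(url, str(target))
    return target


def extract(archive, dest="."):
    """Unpack the archive into dest and return the sorted member names."""
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(Path(dest))
        return sorted(zf.namelist())
```

Running extract(download()) from the repository root mirrors the two shell commands.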

Baselines

Naive

Run the experiment with default parameters:

python experiments/heli/baselines.py --model naive

H25

Run the experiment with default parameters:

python experiments/heli/baselines.py --model H25
cp data/h25/best_net.th trained_models/h25.th

SID

First prepare the data for residual training:

cp data/naive/best_net.th trained_models/naive_baseline.th
python experiments/heli/prepare_residual_dataset.py

Ensure you have MATLAB with the System Identification Toolbox installed, then run the following from within MATLAB:

run_n4sid.m

LSTM

python experiments/heli/train_lstm.py
cp data/heli_lstm/ckpts/best_model.th trained_models/lstm.th

NL (Ours)

First prepare the data for residual training:

cp data/naive/best_net.th trained_models/naive_baseline.th
python experiments/heli/prepare_residual_dataset.py

Run the experiment with default parameters:

python experiments/heli/ceemnl.py

Copy the best model to trained_models:

cp data/NLobsLdyn/ckpts/best_model.th trained_models/NL_model.th

Evaluating and plotting test trajectories

First evaluate the models (the pretrained models are used by default) by running:

python experiments/heli/evaluate_models.py
python experiments/heli/plotting/plotbar.py

Then plot the n-th trajectory in the test set (here n = 9) by running:

python experiments/heli/plotting/plot_trajectories.py --trajectory 9

To plot the circular acceleration prediction (instead of the horizontal one) on the n-th trajectory in the test set:

python experiments/heli/plotting/plot_trajectories.py --trajectory 9 --moments
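
To plot several test trajectories in one go, the two invocations above can be generated programmatically. This is a small sketch that uses only the --trajectory and --moments flags shown above; the helper names plot_command and run_all are hypothetical.

```python
import subprocess
import sys

SCRIPT = "experiments/heli/plotting/plot_trajectories.py"


def plot_command(trajectory, moments=False, python=sys.executable):
    """Build the argument list for plotting one test trajectory."""
    cmd = [python, SCRIPT, "--trajectory", str(trajectory)]
    if moments:
        cmd.append("--moments")
    return cmd


def run_all(trajectories, moments=False):
    """Run the plotting script once per trajectory index, stopping on failure."""
    for n in trajectories:
        subprocess.run(plot_command(n, moments=moments), check=True)
```

Calling run_all(range(3)) from the repository root would plot the trajectories with indices 0, 1, and 2, provided the test set contains them.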

Contributors

kunalmenda, rejuvyesh
