
Comments (3)

kvnlxm commented on June 4, 2024

Hi Alex,
thank you for your kind words and your interest in VAME.

Loading a pretrained model is still a little tricky at this point, and we are planning to make it easier with our next update. To use a pretrained model, follow these steps:

  1. Set the arguments pretrained_weights=True and pretrained_model="vame_model.pkl" in the vame.rnn() function when you train the network.
  2. As you have already mentioned, you have to copy your model into the pretrained_model folder. Make sure that the file name of the pretrained model matches the string passed as pretrained_model in step 1.
  3. Go into your config.yaml and set the hyperparameters kl_start=0 and annealtime=1.
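
Steps 2 and 3 can be checked with a small script before training. This is only a sketch, not part of VAME: the helper names set_yaml_scalars and find_pretrained_model are made up for illustration, the location of the pretrained_model folder is passed in by you, and the config is edited as plain text on the assumption that kl_start and annealtime are top-level keys in config.yaml. The call to vame.rnn() itself is left out because its full signature is not spelled out here.

```python
import re
from pathlib import Path


def set_yaml_scalars(config_path, updates):
    """Set top-level `key: value` lines in a YAML file; other lines are kept verbatim."""
    path = Path(config_path)
    remaining = dict(updates)
    lines = path.read_text().splitlines()
    for i, line in enumerate(lines):
        # Only unindented keys match, so nested entries are left alone.
        match = re.match(r"(\w+)\s*:", line)
        if match and match.group(1) in remaining:
            key = match.group(1)
            # Note: this drops any trailing comment on the replaced line.
            lines[i] = f"{key}: {remaining.pop(key)}"
    # Keys that were not found are appended at the end of the file.
    lines.extend(f"{key}: {value}" for key, value in remaining.items())
    path.write_text("\n".join(lines) + "\n")


def find_pretrained_model(pretrained_dir, pretrained_model):
    """Return the model file inside pretrained_dir, or raise if the name does not match."""
    if Path(pretrained_model).name != pretrained_model:
        # Assumption: a bare file name is expected, as in step 1.
        raise ValueError(f"expected a file name, got a path: {pretrained_model!r}")
    candidate = Path(pretrained_dir) / pretrained_model
    if not candidate.is_file():
        raise FileNotFoundError(f"{candidate} not found; check the name used in step 1")
    return candidate
```

For example, set_yaml_scalars("config.yaml", {"kl_start": 0, "annealtime": 1}) covers step 3, and find_pretrained_model(<your pretrained_model folder>, "vame_model.pkl") fails early if the file name and the argument string disagree. Rejecting a full path is a choice made in this sketch to enforce the name match from step 2; it says nothing about how VAME itself treats a path.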

Now, if you run the function vame.rnn() you should see something like this in your console:
Train RNN model!
Using CUDA
GPU active: True
GPU used: Quadro RTX 5000
Latent Dimensions: 10, Beta: 1, lr: 0.0008
Loading pretrained Model: vame_model
Initialize train data. Datapoints 1306914
Epoch: 1. loss: 65.9462

The line "Loading pretrained Model: vame_model" confirms that the pretrained model is being used.

To answer your second question, yes you can! However, make sure that you used the same pose estimation model on this data as well.

Best,
Kevin

from vame.

alexcwsmith commented on June 4, 2024

Got it, thanks! I am wondering what the impact would be if I ran rnn_model with pretrained_weights=True and an incorrect argument for pretrained_model?

I trained a model overnight last night and passed the full path to the .pkl file as the argument for pretrained_model. It ran for 250 iterations, which I think may have overfit, as the test MSE hadn't changed much for ~100 iterations. I suppose I need to look into why it didn't stop itself; a previous model stopped around iteration 175 and said the model had converged. When I evaluated this model I got these results:

[Figure: Future_Reconstruction]

I haven't seen flat lines like that on any previous model. Do you have any thoughts on that? Does it indicate overfitting, or could it have to do with training this model with a bad pretrained_model argument? Or is the time window too short? I am a bit confused by this plot in general. I've gathered that the X axis is the time window, but is the Y axis showing the latent state for each of my 12 features? I guess I'm confused about what the Y axis represents.

Thanks again!
Alex


kvnlxm commented on June 4, 2024

Hey,

I am closing this issue because, with the new release, resuming with pretrained weights is now optimized and can be done easily. Definitely check out the new version, but also read the wiki, as we have changed how you specify arguments for the vame functions.

Cheers,
Kevin

