
Comments (4)

lubin-liu commented on July 20, 2024

If I understood the issue correctly, I think the cause may be the Gaussian white noise that is injected into the forward pass of ConstrainedNoFeedbackESN, as specified by noise_param. In `Training FORCE with Zebrafish neural data.ipynb`, if you re-run the two prediction cells from the image below multiple times, sometimes the model looks like it hasn't learned (very high error), and the output is not reproducible even though the inputs and weights have not changed. You can set noise_param to (0.0, 0.0) in the layer definition, which zeroes out the noise during the forward pass. Edit: I think setting the noise to 0.0 doesn't work for this layer with the default initial states (neuron firing rates all zero), so you would need to find an initial state that allows learning to occur, which from past experimentation was tricky.

[Screenshot: the two prediction cells from `Training FORCE with Zebrafish neural data.ipynb`]
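For reference, zeroing the noise looks roughly like this; the constructor arguments other than noise_param are placeholders and should be copied from the notebook's actual layer definition:

```python
from tension.constrained import ConstrainedNoFeedbackESN  # import path assumed; match the notebook

# Placeholder arguments -- copy the real values (dtdivtau, structural
# connectivity, etc.) from the notebook. Setting noise_param to (0.0, 0.0)
# removes the Gaussian white noise injected during the forward pass, so
# repeated predictions become deterministic.
esn_layer = ConstrainedNoFeedbackESN(
    units=600,                # placeholder network size
    noise_param=(0.0, 0.0),   # (mean, std) of the injected noise -- zeroed out
)
```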


dmnburrows commented on July 20, 2024

Thanks for the prompt response! I have looked into the noise as a potential source of the problem - it is definitely true that, because of the noise, I cannot reproduce exactly the same trajectories when I re-run model.predict on the same model object after learning (see image below; each trajectory is a different run of model.predict on the same esn_layer after one instance of training, i.e. 50 epochs of model.fit).

[Screenshot: several model.predict trajectories from the same trained esn_layer, each run slightly different]

However, I don't think this is the main source of my problem. If I run model.fit and save the weights, then start a brand-new model object (with the same input parameters as my first model), I want to be able to capture approximately the same dynamics using model.predict, without having to re-fit the model.

As far as I am aware, I need to run model.fit to initialise the new model, so my logic was to compile the new model (using the same parameters as the old, trained model), then run model.fit() for a short period (e.g. 1 epoch) just so that the model has been initialised. I then loaded the recurrent and input weights from the original model into my new model and ran model.predict. Since I am using the weights from the learned model, I expected model.predict to recreate roughly the same dynamics as in the original model after learning - however, I found that the dynamics look vastly different.
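In rough outline (not my exact code; the model class, layer arguments, and data names below are placeholders), this is what I am doing:

```python
from tension.base import FORCEModel                        # model class assumed; substitute whatever the notebook uses
from tension.constrained import ConstrainedNoFeedbackESN

# New layer/model defined with the same hyperparameters as the trained model
# (constructor arguments are placeholders here).
new_layer = ConstrainedNoFeedbackESN(units=600, noise_param=(0.0, 0.0))
new_model = FORCEModel(force_layer=new_layer)
new_model.compile(metrics=["mae"])

# Fit briefly just so the model gets initialised/built ...
new_model.fit(x_train, y_train, epochs=1)

# ... then overwrite the weights with the ones saved from the original model.
new_model.load_weights(checkpoint_path)

# I expected this to reproduce the learned dynamics, but the output changes
# depending on how many epochs were run in the fit above.
prediction = new_model.predict(x_train)
```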

I'm starting to think that I am misunderstanding model.fit and model.predict - I thought all the weight learning happens during model.fit(), and then model.predict() should just generate the dynamics using the learnt weights. However, when I compile my new model, use model.fit() for a few epochs (just to initialise the model so I can use model.predict), and then load in the learnt weights from the original model, I find that the model.predict output changes according to the number of epochs - as the number of epochs increases, the dynamics also change (see below).

[Screenshot: model.predict output from the new model changing as the number of model.fit epochs increases]

Why should model.predict() in the new model be affected by the number of epochs in model.fit() if I am re-loading the weights from the original model? Is there learning occurring during model.predict? Or is it perhaps not correctly using the newly assigned weights? (I have used model.load_weights(checkpoint_path), then checked esn_layer.recurrent_kernel and esn_layer.input_kernel, and they seem correct.)

Thanks for the help! Let me know if seeing some of my code would be helpful.


lubin-liu commented on July 20, 2024

From your description above, it seems like when you created the new model after loading in the weights, the new RNN layer passed into this new model did not use the final state from the previous model's layer (the neuron pre-activation firing rates)? If so, then when you called model.fit the second time, the neuron firing rates start from all zeros (or randomly initialised, depending on the pre-defined get_initial_state method), leading to divergent results. The RNN layer's states are technically not weights, so I don't think they are saved by model.save_weights; they would have to be saved separately. Below is a simplified version of the Zebrafish example, with a dummy initial_a (random Gaussian), the noise removed, and run for 5 epochs:

[Screenshot: simplified Zebrafish example with random Gaussian initial_a and noise removed, trained for 5 epochs]
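Not the exact cell from the screenshot, but in outline it was along these lines (the model class, argument values, state shape, and data names are placeholders):

```python
import numpy as np
from tension.base import FORCEModel                        # model class assumed
from tension.constrained import ConstrainedNoFeedbackESN

units = 600  # placeholder network size

# Dummy initial pre-activation firing rates (random Gaussian), so that learning
# can proceed even with the injected noise turned off. Shape is assumed; match
# whatever the layer expects.
initial_a = np.random.normal(size=(1, units)).astype("float32")

esn_layer = ConstrainedNoFeedbackESN(
    units=units,
    noise_param=(0.0, 0.0),   # noise removed
    initial_a=initial_a,      # dummy initial state
)

model = FORCEModel(force_layer=esn_layer)
model.compile(metrics=["mae"])
model.fit(x_train, y_train, epochs=5)
```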

I re-ran the same cell for only 2 epochs and saved the weights (the error in the first two epochs matches the above):

[Screenshot: the same cell run for 2 epochs, with the weights saved; the first two epochs' errors match the 5-epoch run]

The states of the RNN layer can be accessed via model.force_layer.states, the states being a tuple of the elements listed in the ConstrainedNoFeedbackESN documentation.
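For example, after training you can stash both the weights and the first state, since the state is not captured by save_weights (the file name here is arbitrary):

```python
import numpy as np

# save_weights captures the trainable/non-trainable weights ...
model.save_weights(checkpoint_path)

# ... but not the layer states, so save the first state (the pre-activation
# firing rates for ConstrainedNoFeedbackESN) separately.
final_a = np.asarray(model.force_layer.states[0])
np.save("final_a.npy", final_a)
```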

To test defining a new model with a new layer using the saved weights from the previous one, the first state of the previous layer should be passed into the initial_a parameter of the new layer definition, as below. You can also build the model manually by calling the build method with the correct input shape (as opposed to fitting it again for 1 epoch), then load in the weights and either fit or predict. When I fit again for 3 epochs, the error matched the error in the last 3 epochs of the first figure. The changes are the 3 lines indicated below.

[Screenshot: new model defined with initial_a set to the previous layer's first state, built manually, and loaded with the saved weights; 3 lines changed]
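Not reproducing the screenshot exactly, but the three changed lines amount to roughly the following (everything else is as in the earlier cell; x_train and checkpoint_path are placeholders):

```python
import numpy as np
from tension.base import FORCEModel                        # model class assumed
from tension.constrained import ConstrainedNoFeedbackESN

units = 600  # same placeholder size as before

# (1) pass the saved first state of the previous layer in as initial_a
saved_a = np.load("final_a.npy")
new_layer = ConstrainedNoFeedbackESN(
    units=units,
    noise_param=(0.0, 0.0),
    initial_a=saved_a,
)

new_model = FORCEModel(force_layer=new_layer)
new_model.compile(metrics=["mae"])

# (2) build the model manually with the correct input shape instead of
#     fitting it for one throwaway epoch
new_model.build(input_shape=x_train.shape)

# (3) load the previously saved weights, then either fit or predict
new_model.load_weights(checkpoint_path)
prediction = new_model.predict(x_train)
```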


dmnburrows commented on July 20, 2024

This fixed it! Thank you so much :D

