fancompute / neuroptica

Flexible simulation package for optical neural networks

Home Page: https://doi.org/10.1109/JSTQE.2019.2930455

License: MIT License

Python 100.00%
photonics nanophotonics neural-network optics machine-learning

neuroptica's Issues

Loss function is NaN when using AbsSquared activation

Running this model

model_linear = neu.Sequential([
    neu.ClementsLayer(N),
    neu.Activation(neu.AbsSquared(N)),
    neu.DropMask(N, keep_ports=range(N_classes))
])

losses = neu.InSituAdam(model_linear, neu.CategoricalCrossEntropy, step_size=step_size).fit(x_train_flattened, y_train_onehot, epochs=n_epochs, batch_size=batch_size)

gives the warning:

  X_softmax = np.exp(X) / np.sum(np.exp(X), axis=0)
../neuroptica/neuroptica/losses.py:45: RuntimeWarning: invalid value encountered in true_divide
  X_softmax = np.exp(X) / np.sum(np.exp(X), axis=0)

and the loss function is NaN.

Changing AbsSquared to Abs makes it work fine.
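One plausible reading of the warning, assuming the inputs are not normalized: AbsSquared squares the output amplitudes, so the values reaching the softmax can exceed about 709, above which float64 `np.exp` overflows to `inf`. The division then evaluates `inf / inf`, which is `nan`. Abs keeps those values much smaller, which would be consistent with it working. The sketch below contrasts the expression from the warning with the usual max-subtraction form of softmax; the function names are hypothetical and this is not neuroptica's own code.

```python
import numpy as np

def softmax_naive(X):
    # Same expression as the line in the warning: np.exp overflows to inf for
    # inputs above roughly 709 (float64), and inf / inf evaluates to nan.
    return np.exp(X) / np.sum(np.exp(X), axis=0)

def softmax_stable(X):
    # Subtracting the per-column maximum does not change the softmax
    # mathematically, but keeps every exponent <= 0 so np.exp cannot overflow.
    X_shift = X - np.max(X, axis=0, keepdims=True)
    e = np.exp(X_shift)
    return e / np.sum(e, axis=0, keepdims=True)

# Rows are classes, columns are samples (matching axis=0 in the original line).
X = np.array([[1000.0, 1.0],
              [999.0, 2.0]])

with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(X))  # first column is nan, second column is finite
print(softmax_stable(X))     # every column is finite and sums to 1
```

Normalizing the input vectors before training would be another way to keep AbsSquared outputs in a safe range.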

Sketchy Loss Functions

When I train simple linear models, the loss function oscillates wildly. For example, using InSituAdam:

[image: loss curve]

model_linear = neu.Sequential([
    neu.ClementsLayer(N),
    neu.Activation(neu.Abs(N)),
    neu.DropMask(N, keep_ports=range(N_classes))
])

losses = neu.InSituAdam(model_linear, neu.CategoricalCrossEntropy, step_size=step_size).fit(x_train_flattened, y_train_onehot, epochs=n_epochs, batch_size=batch_size)

This may be a sign that the gradients are incorrect; they should be double-checked.
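A standard way to do that check is to compare the analytic gradient with a central finite-difference estimate. The following is a generic sketch, not part of neuroptica: `toy_loss` and `toy_grad` are a hypothetical stand-in (the output power of a phase-shifted arm interfering with a reference arm) whose derivative is known in closed form. To test the library, one would wrap the model's loss as a function of its phase parameters and pass the gradients the optimizer uses as `grad_f`, assuming they can be extracted.

```python
import numpy as np

def numerical_gradient(f, theta, eps=1e-6):
    """Central-difference estimate of df/dtheta for a real scalar function f."""
    grad = np.zeros_like(theta)
    for k in range(theta.size):
        step = np.zeros_like(theta)
        step.flat[k] = eps
        grad.flat[k] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return grad

def gradient_check(f, grad_f, theta, eps=1e-6):
    """Maximum relative error between the analytic and finite-difference gradients."""
    g_analytic = grad_f(theta)
    g_numeric = numerical_gradient(f, theta, eps)
    denom = np.maximum(np.abs(g_analytic) + np.abs(g_numeric), 1e-12)
    return np.max(np.abs(g_analytic - g_numeric) / denom)

# Hypothetical toy loss: P(phi) = |1 + exp(i*phi)|^2 / 4 = (1 + cos(phi)) / 2,
# summed over independent arms, so dP/dphi = -sin(phi) / 2.
def toy_loss(phi):
    return np.sum(np.abs(1 + np.exp(1j * phi)) ** 2 / 4)

def toy_grad(phi):
    return -np.sin(phi) / 2

phi = np.array([0.3, 1.2, -2.0])
print(gradient_check(toy_loss, toy_grad, phi))  # a value near zero means the gradients agree
```

A relative error near 1 (for example, from a sign error in the analytic gradient) would point to a bug rather than to an unlucky step size.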

IndexError when trying to train a network with N = 2

Not sure if I am doing something stupid, but I get the following error when trying to train a mesh of dimension N = 2.

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-64-b9c14e8893da> in <module>
     23 Y_formatted = Y.T
     24 
---> 25 losses = neu.InSituAdam(model, neu.CategoricalCrossEntropy, step_size=0.005).fit(X_formatted, Y_formatted, epochs=1000, batch_size=32)
     26 
     27 plt.plot(losses)

~/drive/Research/Projects/ONN/neuroptica/neuroptica/optimizers.py in fit(self, data, labels, epochs, batch_size, show_progress)
    169 
    170                 # Compute the backpropagated signals for the model
--> 171                 deltas = self.model.backward_pass(d_loss)
    172                 delta_prev = d_loss  # backprop signal to send in the final layer
    173 

~/drive/Research/Projects/ONN/neuroptica/neuroptica/models.py in backward_pass(self, d_loss)
     59         gradients = {"output": d_loss}
     60         for layer in reversed(self.layers):
---> 61             backprop_signal = layer.backward_pass(backprop_signal)
     62             gradients[layer.__name__] = backprop_signal
     63         return gradients

~/drive/Research/Projects/ONN/neuroptica/neuroptica/layers.py in backward_pass(self, delta)
     50         delta_back = np.zeros((self.input_size, n_samples), dtype=NP_COMPLEX)
     51         for i in range(n_features):
---> 52             delta_back[self.ports[i]] = delta[i]
     53         return delta_back
     54 

IndexError: list index out of range
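The message `list index out of range` is raised when indexing a Python list, so the failing expression is most likely `self.ports[i]` rather than `delta[i]` (a NumPy array reports an out-of-bounds index with a different message). That means the backprop signal reaching the mask has more rows (`n_features`) than the mask has kept ports. The sketch below is a hypothetical stand-alone version of the loop in the traceback with an explicit shape check added; it assumes `delta` is shaped `(n_features, n_samples)`, as the `np.zeros((self.input_size, n_samples))` line suggests, and the function name is not part of neuroptica.

```python
import numpy as np

def drop_mask_backward(delta, ports, input_size):
    """Scatter the backprop signal of the kept ports back into the full mesh width.

    Hypothetical re-implementation of the loop in the traceback, with a shape
    check so a mismatch is reported before any indexing happens.
    """
    ports = list(ports)
    n_features, n_samples = delta.shape
    if n_features != len(ports):
        raise ValueError(
            f"delta has {n_features} rows but the mask keeps {len(ports)} ports; "
            "check that labels are shaped (n_classes, n_samples)")
    delta_back = np.zeros((input_size, n_samples), dtype=np.complex128)
    for i in range(n_features):
        # Each kept port receives the gradient row for its class; dropped ports stay zero.
        delta_back[ports[i]] = delta[i]
    return delta_back
```

With N = 2 and `keep_ports=range(2)`, a `delta` of shape `(2, 32)` passes, while one with more than two rows raises the `ValueError` instead of an `IndexError`.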
