
Multislice PHATE for tensor embeddings

Home Page: https://arxiv.org/abs/1908.02831

License: GNU General Public License v3.0


M-PHATE


Demonstration M-PHATE plot

Multislice PHATE (M-PHATE) is a dimensionality reduction algorithm for the visualization of time-evolving data. To learn more about M-PHATE, you can read our preprint on arXiv, in which we apply it to the evolution of neural networks over the course of training. Above, we show a demonstration of M-PHATE applied to a 3-layer MLP over 300 epochs of training, colored by epoch (left), hidden layer (center), and the digit label that most strongly activates each hidden unit (right). Below, we show the same network trained with dropout, embedded in 3D and again colored by the most strongly activating digit.


3D rotating gif

How it works

Multislice PHATE (M-PHATE) combines a novel multislice kernel construction with the PHATE visualization. Our kernel captures the dynamics of an evolving graph structure which, when visualized, gives unique intuition about the evolution of a system; in our preprint, we show this applied to a neural network over the course of training and re-training. We compare M-PHATE to other dimensionality reduction techniques, showing that the combined construction of the multislice kernel and the use of PHATE provides significant improvements to visualization. In two vignettes, we demonstrate the use of M-PHATE on established training tasks and learning methods in continual learning, and on regularization techniques commonly used to improve generalization performance.

The multislice kernel used in M-PHATE consists of building graphs over time slices of data (e.g. epochs in neural network training) and then connecting these slices by connecting each point to itself over time, weighted by its similarity. The result is a highly sparse, structured kernel which provides insight into the evolving structure of the data.
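
As a rough illustration of this construction, below is a minimal NumPy sketch that assembles the combined kernel from intraslice and interslice affinities. This is hypothetical illustration code, not the library's implementation: the real construction in m_phate uses adaptive kNN kernels, while this sketch uses fixed-bandwidth Gaussian affinities, and the function and bandwidth parameters are invented for clarity.

import numpy as np
from scipy.spatial.distance import cdist

def multislice_kernel_sketch(data, intraslice_bw=1.0, interslice_bw=1.0):
    # data: array of shape (n_slices, n_points, n_dim)
    n_slices, n_points, _ = data.shape
    K = np.zeros((n_slices * n_points, n_slices * n_points))
    # intraslice blocks: affinities between all points within one time slice
    for t in range(n_slices):
        sl = slice(t * n_points, (t + 1) * n_points)
        K[sl, sl] = np.exp(-(cdist(data[t], data[t]) / intraslice_bw) ** 2)
    # interslice blocks: each point connects only to itself across slices,
    # weighted by how similar its features are at the two time points
    for i in range(n_points):
        idx = i + n_points * np.arange(n_slices)
        w = np.exp(-(cdist(data[:, i], data[:, i]) / interslice_bw) ** 2)
        K[np.ix_(idx, idx)] = w
    return K

In the full method, a kernel of this form is normalized into a diffusion operator and embedded with PHATE.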

For more details, check out our NeurIPS publication, read the tweetorial or have a look at our poster.

Example of multislice graph

Example of multislice kernel

Installation

Install from PyPI

pip install --user m-phate

Install from source

pip install --user git+https://github.com/scottgigante/m-phate.git

Usage

Basic usage example

Below we apply M-PHATE to simulated data of 50 points undergoing random motion.

import numpy as np
import m_phate
import scprep

# create fake data
n_time_steps = 100
n_points = 50
n_dim = 25
np.random.seed(42)
data = np.cumsum(np.random.normal(0, 1, (n_time_steps, n_points, n_dim)), axis=0)

# embedding
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(data)

# plot
time = np.repeat(np.arange(n_time_steps), n_points)
scprep.plot.scatter2d(m_phate_data, c=time, ticks=False, label_prefix="M-PHATE")

Example embedding
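
The embedding has one row per point per time step, flattened time-major, which is why np.repeat recovers the time labels above. Since M_PHATE builds on PHATE, it should also accept PHATE keyword arguments such as n_components; a sketch of a 3D version of the same embedding, under that assumption:

# 3D embedding of the same data (continuing from the snippet above)
m_phate_op_3d = m_phate.M_PHATE(n_components=3)
m_phate_data_3d = m_phate_op_3d.fit_transform(data)
scprep.plot.scatter3d(m_phate_data_3d, c=time, ticks=False,
                      label_prefix="M-PHATE")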

Network training

To apply M-PHATE to neural networks, we provide helper classes to store samples of the network's hidden activations during training. To use these, you must also install tensorflow and keras.

import numpy as np

import keras
import scprep

import m_phate
import m_phate.train
import m_phate.data

# load data
x_train, x_test, y_train, y_test = m_phate.data.load_mnist()

# select trace examples
trace_idx = [np.random.choice(np.argwhere(y_test[:, i] == 1).flatten(),
                              10, replace=False)
             for i in range(10)]
trace_data = x_test[np.concatenate(trace_idx)]

# build neural network
lrelu = keras.layers.LeakyReLU(alpha=0.1)
inputs = keras.layers.Input(
    shape=(x_train.shape[1],), dtype='float32', name='inputs')
h1 = keras.layers.Dense(128, activation=lrelu, name='h1')(inputs)
h2 = keras.layers.Dense(64, activation=lrelu, name='h2')(h1)
h3 = keras.layers.Dense(128, activation=lrelu, name='h3')(h2)
outputs = keras.layers.Dense(10, activation='softmax', name='output_all')(h3)

# build trace model helper
model_trace = keras.models.Model(inputs=inputs, outputs=[h1, h2, h3])
trace = m_phate.train.TraceHistory(trace_data, model_trace)

# compile network
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['categorical_accuracy', 'categorical_crossentropy'])

# train network
model.fit(x_train, y_train, batch_size=128, epochs=200,
          verbose=1, callbacks=[trace],
          validation_data=(x_test, y_test))

# extract trace data (note: this reuses the name trace_data, replacing the
# raw trace examples selected above with the recorded activations)
trace_data = np.array(trace.trace)
epoch = np.repeat(np.arange(trace_data.shape[0]), trace_data.shape[1])

# apply M-PHATE
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(trace_data)

# plot the result
scprep.plot.scatter2d(m_phate_data, c=epoch, ticks=False,
                      label_prefix="M-PHATE")
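
Within each epoch, the rows of trace_data correspond to individual hidden units, so the embedding can also be colored by layer, as in the demonstration figure above. A sketch, assuming units are stored layer by layer in the order h1 (128 units), h2 (64), h3 (128):

# label each hidden unit with its layer index; this assumes layer-major
# ordering of units within each epoch
layer_ids = np.concatenate([np.full(128, 0), np.full(64, 1), np.full(128, 2)])
layer = np.tile(layer_ids, trace_data.shape[0])
scprep.plot.scatter2d(m_phate_data, c=layer, ticks=False,
                      label_prefix="M-PHATE")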

Example notebooks

For detailed examples, see our sample notebooks for keras and tensorflow in the examples directory.

Parameter tuning

Tuning the parameters of M-PHATE essentially comes down to balancing the tradeoff between interslice connectivity and intraslice connectivity, controlled primarily by interslice_knn and intraslice_knn. You can see an example of the effects of parameter tuning in this notebook.
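
For instance (the values below are illustrative, not the library's defaults):

# raising interslice_knn ties each unit more strongly to itself across
# epochs; raising intraslice_knn emphasizes structure within each epoch
m_phate_op = m_phate.M_PHATE(interslice_knn=25, intraslice_knn=5)
embedding = m_phate_op.fit_transform(data)  # data as in the usage example above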

Figure reproduction

We provide scripts to reproduce all of the empirical figures in the preprint.

To run them:

git clone https://github.com/scottgigante/m-phate
cd m-phate
pip install --user .

# change this if you want to store the data elsewhere
DATA_DIR=~/data/checkpoints/m_phate

# choose between cifar and mnist
DATASET="mnist"
EXTRA_ARGS="--dataset ${DATASET}"

# remove to use validation data
EXTRA_ARGS="${EXTRA_ARGS} --sample-train-data"

chmod +x scripts/generalization/generalization_train.sh
chmod +x scripts/task_switching/classifier_mnist_task_switch_train.sh

./scripts/generalization/generalization_train.sh "${DATA_DIR}" "${EXTRA_ARGS}"
./scripts/task_switching/classifier_mnist_task_switch_train.sh "${DATA_DIR}" "${EXTRA_ARGS}"

python scripts/demonstration_plot.py "${DATA_DIR}" "${DATASET}"
python scripts/comparison_plot.py "${DATA_DIR}" "${DATASET}"
python scripts/generalization_plot.py "${DATA_DIR}" "${DATASET}"
python scripts/task_switch_plot.py "${DATA_DIR}" "${DATASET}"

TODO

  • Provide support for PyTorch
  • Notebook examples for:
    • Classification (PyTorch)
    • Autoencoder (PyTorch)
  • Build readthedocs page

Help

If you have any questions, please feel free to open an issue.

m-phate's People

Contributors

scottgigante


m-phate's Issues

Inf/NaN encountered during landmark calculation

I'm trying to use M-PHATE to visualize changes in a neural network during a continual learning method. The network is quite a bit larger than the examples (2400 hidden neurons in total), and the transformation crashed during the landmark operator calculation: according to the traceback, Inf or NaN values were encountered. Earlier, during the graph/diffusion operator calculation, there was also a warning that invalid values were encountered.

I don't see any invalid data or discrepancies in the input tensor that would easily explain the issue. I'm not very familiar with the graph operations going on under the hood; is it possible that a parameter needs to change for the larger number of hidden neurons, or is there an issue in the code? The input tensor of activations I used is available here.

Calculating M-PHATE...
  Calculating multislice kernel...
  Calculated multislice kernel in 233.67 seconds.
  Calculating graph and diffusion operator...

/home/wolin/anaconda3/envs/mphate/lib/python3.7/site-packages/scipy/sparse/compressed.py:213: RuntimeWarning: invalid value encountered in less
  res = self._with_data(op(self.data, other), copy=True)
/home/wolin/anaconda3/envs/mphate/lib/python3.7/site-packages/graphtools/graphs.py:1037: RuntimeWarning: invalid value encountered in less
  K.data[K.data < self.thresh] = 0

    Calculating landmark operator...
      Calculating SVD...
      Calculated SVD in 41.95 seconds.
    Calculated landmark operator in 41.95 seconds.
  Calculated graph and diffusion operator in 91.94 seconds.
Calculated M-PHATE in 326.34 seconds.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/anaconda3/envs/mphate/lib/python3.7/site-packages/graphtools/graphs.py in landmark_op(self)
    588         try:
--> 589             return self._landmark_op
    590         except AttributeError:

AttributeError: 'TraditionalLandmarkGraph' object has no attribute '_landmark_op'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-141-aaf34eb39e68> in <module>
      1 # calculate M-PHATE
      2 m_phate_op = m_phate.M_PHATE()
----> 3 m_phate_data = m_phate_op.fit_transform(activations)

~/anaconda3/envs/mphate/lib/python3.7/site-packages/m_phate/m_phate.py in fit_transform(self, X, **kwargs)
    191         """
    192         with _logger.task('M-PHATE'):
--> 193             self.fit(X)
    194             embedding = self.transform(**kwargs)
    195         return embedding

~/anaconda3/envs/mphate/lib/python3.7/site-packages/m_phate/m_phate.py in fit(self, X)
    167                 random_state=self.random_state,
    168                 **(self.kwargs))
--> 169             self.diff_op
    170         result = super().fit(self.graph)
    171         return result

~/anaconda3/envs/mphate/lib/python3.7/site-packages/phate/phate.py in diff_op(self)
    274         if self.graph is not None:
    275             if isinstance(self.graph, graphtools.graphs.LandmarkGraph):
--> 276                 diff_op = self.graph.landmark_op
    277             else:
    278                 diff_op = self.graph.diff_op

~/anaconda3/envs/mphate/lib/python3.7/site-packages/graphtools/graphs.py in landmark_op(self)
    589             return self._landmark_op
    590         except AttributeError:
--> 591             self.build_landmark_op()
    592             return self._landmark_op
    593 

~/anaconda3/envs/mphate/lib/python3.7/site-packages/graphtools/graphs.py in build_landmark_op(self)
    659                     self.diff_aff,
    660                     n_components=self.n_svd,
--> 661                     random_state=self.random_state,
    662                 )
    663             with _logger.task("KMeans"):

~/anaconda3/envs/mphate/lib/python3.7/site-packages/sklearn/utils/extmath.py in randomized_svd(M, n_components, n_oversamples, n_iter, power_iteration_normalizer, transpose, flip_sign, random_state)
    346 
    347     Q = randomized_range_finder(M, n_random, n_iter,
--> 348                                 power_iteration_normalizer, random_state)
    349 
    350     # project M to the (k + p) dimensional space using the basis vectors

~/anaconda3/envs/mphate/lib/python3.7/site-packages/sklearn/utils/extmath.py in randomized_range_finder(A, size, n_iter, power_iteration_normalizer, random_state)
    230             Q = safe_sparse_dot(A.T, Q)
    231         elif power_iteration_normalizer == 'LU':
--> 232             Q, _ = linalg.lu(safe_sparse_dot(A, Q), permute_l=True)
    233             Q, _ = linalg.lu(safe_sparse_dot(A.T, Q), permute_l=True)
    234         elif power_iteration_normalizer == 'QR':

~/anaconda3/envs/mphate/lib/python3.7/site-packages/scipy/linalg/decomp_lu.py in lu(a, permute_l, overwrite_a, check_finite)
    208     """
    209     if check_finite:
--> 210         a1 = asarray_chkfinite(a)
    211     else:
    212         a1 = asarray(a)

~/anaconda3/envs/mphate/lib/python3.7/site-packages/numpy/lib/function_base.py in asarray_chkfinite(a, dtype, order)
    497     if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all():
    498         raise ValueError(
--> 499             "array must not contain infs or NaNs")
    500     return a
    501 

ValueError: array must not contain infs or NaNs

PyTorch support

Hi! Thanks for the great project! I am interested in using your package w/ PyTorch networks.

Do you have pointers? I guess the key piece is the _History object?

If it's easy to do, maybe I can help with the PR :)
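
Until first-class support lands (see the TODO above), one way to collect an equivalent trace tensor from a PyTorch model is with forward hooks. A minimal sketch, not part of m-phate; the class and method names here are hypothetical:

import numpy as np
import torch

class TorchTraceHistory:
    # analogous in spirit to m_phate.train.TraceHistory: record each
    # layer's activations on fixed trace examples once per epoch
    def __init__(self, trace_data, model, layers):
        self.trace_data = torch.as_tensor(trace_data, dtype=torch.float32)
        self.model = model
        self.layers = layers
        self.trace = []

    def on_epoch_end(self):
        activations = []
        # capture each tracked layer's output on the trace examples
        hooks = [layer.register_forward_hook(
                     lambda module, inp, out: activations.append(
                         out.detach().cpu().numpy()))
                 for layer in self.layers]
        with torch.no_grad():
            self.model(self.trace_data)
        for h in hooks:
            h.remove()
        # one row per hidden unit, one column per trace example
        self.trace.append(np.concatenate([a.T for a in activations], axis=0))

Calling on_epoch_end() after each training epoch should yield a tensor with the same per-epoch layout as the keras example above (one row per hidden unit), ready to stack with np.array and pass to M_PHATE.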
