This project forked from raman-lab-ucla/explainability_for_photonics

Explainability for Photonics

Introduction

Welcome to the Raman Lab GitHub! This repo will walk you through the code used in the following publication: https://pubs.acs.org/doi/full/10.1021/acsphotonics.0c01067

Here, we use Deep SHAP (or SHAP) to explain the behavior of nanophotonic structures learned by a convolutional neural network (CNN).

Requirements

The following libraries are required to run the provided scripts. Specific versions are needed because of compatibility issues between TensorFlow and SHAP (as of this writing).

-Python 3.7.4

-TensorFlow 1.14.0

-SHAP 0.31.0

-OpenCV (CV2) 3.4.2

-NumPy 1.17.3

Installation and usage instructions for Deep SHAP are at: https://github.com/slundberg/shap
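
Since the scripts depend on specific versions, it can help to check the environment before running them. The following is a minimal sketch, not part of the repo: the helper names are made up for illustration, and it assumes each library exposes a `__version__` attribute under the import names shown.

```python
import importlib

# Versions listed in the Requirements section, keyed by import name.
REQUIRED = {
    "tensorflow": "1.14.0",
    "shap": "0.31.0",
    "cv2": "3.4.2",
    "numpy": "1.17.3",
}

def installed_version(module_name):
    """Return the module's __version__ string, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)

def find_mismatches(required, lookup=installed_version):
    """Return {name: (wanted, found)} for every module whose version differs."""
    mismatches = {}
    for name, wanted in required.items():
        found = lookup(name)
        # Prefix match, so a build suffix such as "3.4.2.17" still passes.
        matches = found is not None and (found == wanted or found.startswith(wanted + "."))
        if not matches:
            mismatches[name] = (wanted, found)
    return mismatches

if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(REQUIRED).items():
        print(f"{name}: need {wanted}, found {found or 'not installed'}")
```

An empty result means every listed library was found at the expected version. The Python version itself is not checked here.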

Steps

1) Train the CNN (CNN_Train.py)

Download the files in the 'Training Data' folder and update the following lines in the 'CNN_Train.py' file:

## Define File Locations (Images, Spectra, and CNN Model Save)
img_path = 'C:/.../*.png'
spectra_path = 'C:/.../Spectra.csv'
save_dir = 'C:/.../model.h5'

Running this file will train the CNN and save the model in the specified location. Depending on the available hardware, the CNN training process can take up to a few hours.

2) Explain CNN Results (SHAP_Explanation.py)

Deep SHAP explains the predictions for a 'Base' image in reference to a 'Background'. The Background can be a collection of images or a single image. To minimize noise, we recommend using a 'white' image as the Base and the image to be evaluated as the Background. This compares the importance of a feature with the absence of that feature, with respect to a target output. Update the following paths and run the 'SHAP_Explanation.py' script (see the Examples folder for sample Background and Base images):

## Define File Locations (CNN, Test Image, and Background Image)
model = load_model('C:/.../model.h5', compile=False)
back_img_path = 'C:/.../Background.png'
base_img_path = 'C:/.../Base.png'
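
To make the idea of explaining one input "in reference to" another concrete, here is a minimal sketch that uses a toy linear model in place of the CNN. For a linear model and a single reference input, the SHAP value of each feature is its weight times the difference between the input and the reference, and the values sum to the difference between the two predictions. The weights, pixel values, and function names are made up for illustration and are not taken from SHAP_Explanation.py.

```python
def predict(weights, bias, pixels):
    """Toy linear 'model': weighted sum of pixel values plus a bias."""
    return sum(w * p for w, p in zip(weights, pixels)) + bias

def linear_shap(weights, pixels, reference):
    """SHAP values of a linear model with respect to a single reference input."""
    return [w * (p - r) for w, p, r in zip(weights, pixels, reference)]

# Hypothetical 4-pixel example (values chosen for illustration only).
weights = [0.5, -0.25, 0.0, 1.0]
bias = 0.1
image = [1.0, 1.0, 0.0, 1.0]      # input being explained
reference = [0.0, 0.0, 0.0, 0.0]  # reference the input is compared against

values = linear_shap(weights, image, reference)
gap = predict(weights, bias, image) - predict(weights, bias, reference)

print(values)            # per-pixel contributions
print(sum(values), gap)  # the two numbers agree up to floating-point rounding
```

A pixel that is identical in the input and the reference receives a SHAP value of zero, which is why the choice of reference determines what "absence of a feature" means.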

After running the script, a list of SHAP value heatmaps (shap_values) will be generated. The size and order of this list reflect the CNN's outputs, and the resolution of each heatmap is the same as that of the CNN input images. Therefore, to plot a specific heatmap (corresponding to a particular wavelength), index the list as follows:

shap.image_plot(shap_values[i], back_img.reshape(1,40,40,1), show=False) #where 'i' is an index from 0 to len(shap_values) - 1
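
Because each list entry corresponds to one CNN output, choosing the heatmap for a given wavelength amounts to finding the matching output index. A minimal sketch, assuming the outputs are evenly spaced samples of the spectrum; the wavelength range, the number of outputs, and the function names below are placeholders, not values from the paper or the scripts:

```python
def wavelength_grid(start, stop, count):
    """Evenly spaced wavelengths, one per CNN output (both endpoints included)."""
    step = (stop - start) / (count - 1)
    return [start + k * step for k in range(count)]

def index_for_wavelength(wavelengths, target):
    """Index of the grid point closest to the target wavelength."""
    return min(range(len(wavelengths)), key=lambda k: abs(wavelengths[k] - target))

# Placeholder grid: 5 outputs spanning 4 to 12 (arbitrary units).
grid = wavelength_grid(4.0, 12.0, 5)   # [4.0, 6.0, 8.0, 10.0, 12.0]
i = index_for_wavelength(grid, 9.2)    # closest grid point is 10.0, so i == 3
print(i)
```

The resulting `i` can then be used as `shap_values[i]` in the plotting call.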

Optionally, for ease of viewing, the SHAP values can be normalized and replotted like so:

import numpy as np
import pickle
import matplotlib.pyplot as plt
import matplotlib.colors as colors

with open('C:/.../shap_explanations.data', 'rb') as filehandle:
    shap_values = pickle.load(filehandle)
    
X = np.arange(-20, 20, 1)
Y = np.arange(-20, 20, 1)
X, Y = np.meshgrid(X, Y)

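maximum = np.max(shap_values)   # largest positive SHAP value across all outputs
minimum = -np.min(shap_values)  # magnitude of the most negative SHAP value

i = 0  # index of the output to plot, from 0 to len(shap_values) - 1
shap_i = np.array(shap_values[i], dtype=float)  # copy, so the loaded values are not modified in place
shap_i[shap_i>0] = shap_i[shap_i>0] / maximum   # scale positive values into (0, 1]
shap_i[shap_i<0] = shap_i[shap_i<0] / minimum   # scale negative values into [-1, 0)
shap_values_normalized = shap_i.squeeze()[::-1] # drop singleton axes and flip vertically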

fig = plt.figure()
ax = fig.gca()
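pcm = ax.pcolormesh(X, Y, shap_values_normalized, norm=colors.SymLogNorm(linthresh=0.01, linscale=1, vmin=-1, vmax=1), cmap='bwr') #color limits are set on the norm, since recent Matplotlib versions reject vmin/vmax passed alongside norm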
fig.colorbar(pcm)
ax.axis('off')

3) Explanation Validation (SHAP_Validation.py)

To validate that the explanations represent physical phenomena, we used the SHAP explanations to reconstruct the original image, which can either suppress or enhance an absorption spectrum. This reconstructed image can be imported directly into EM simulation software (e.g., Lumerical FDTD). Run the 'SHAP_Validation.py' script after specifying the location of the saved SHAP values:

#Import SHAP Values
with open('C:/.../shap_explanations.data', 'rb') as filehandle:
    shap_values = pickle.load(filehandle)

Tune the conversion settings by modifying the following line in the script:

if np.max(shap_values_convert) > shap_values_convert[i][j] > np.max(shap_values_convert)*0.05: #Convert Top 95% of Red Pixels        
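
The condition above selects pixels whose SHAP value lies between 5% of the map's maximum and the maximum itself. Because both comparisons are strict, the pixel holding the maximum is not selected. A minimal sketch of that selection on a small made-up map, using plain nested lists in place of the script's arrays; the function name, the `fraction` parameter, and the sample values are illustrative and do not come from SHAP_Validation.py:

```python
def select_pixels(shap_map, fraction=0.05):
    """Boolean mask of pixels with fraction * max < value < max (strict bounds)."""
    peak = max(max(row) for row in shap_map)
    return [[peak > value > peak * fraction for value in row] for row in shap_map]

# Made-up 2x3 SHAP map; positive values correspond to the 'red' pixels.
sample = [
    [0.00, 0.02, 0.40],
    [-0.30, 0.06, 1.00],
]
mask = select_pixels(sample)
# mask == [[False, False, True], [False, True, False]]
print(mask)
```

Raising `fraction` keeps only the strongest positive contributions, while lowering it keeps more of them.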

Citation

If you find this repo helpful, or use any of the code you find here, please cite our work using the following:

C. Yeung, et al. Elucidating the Behavior of Nanophotonic Structures through Explainable Machine Learning Algorithms. ACS Photonics, 2020. 
